Merge pull request #12 from matrix-org/federation_authorization

Federation authorization
This commit is contained in:
Mark Haines 2014-11-11 16:40:50 +00:00
commit a8ceeec0fd
71 changed files with 3834 additions and 3973 deletions

View File

@ -32,7 +32,7 @@ for port in 8080 8081 8082; do
-D --pid-file "$DIR/$port.pid" \
--manhole $((port + 1000)) \
--tls-dh-params-path "demo/demo.tls.dh" \
$PARAMS
$PARAMS $SYNAPSE_PARAMS
python -m synapse.app.homeserver \
--config-path "demo/etc/$port.config" \

View File

@ -1,13 +1,13 @@
Signing JSON
============
JSON is signed by encoding the JSON object without ``signatures`` or ``meta``
JSON is signed by encoding the JSON object without ``signatures`` or ``unsigned``
keys using a canonical encoding. The JSON bytes are then signed using the
signature algorithm and the signature encoded using base64 with the padding
stripped. The resulting base64 signature is added to an object under the
*signing key identifier* which is added to the ``signatures`` object under the
name of the server signing it which is added back to the original JSON object
along with the ``meta`` object.
along with the ``unsigned`` object.
The *signing key identifier* is the concatenation of the *signing algorithm*
and a *key version*. The *signing algorithm* identifies the algorithm used to
@ -15,8 +15,8 @@ sign the JSON. The currently support value for *signing algorithm* is
``ed25519`` as implemented by NACL (http://nacl.cr.yp.to/). The *key version*
is used to distinguish between different signing keys used by the same entity.
The ``meta`` object and the ``signatures`` object are not covered by the
signature. Therefore intermediate servers can add metadata such as time stamps
The ``unsigned`` object and the ``signatures`` object are not covered by the
signature. Therefore intermediate servers can add unsigneddata such as time stamps
and additional signatures.
@ -27,7 +27,7 @@ and additional signatures.
"signing_keys": {
"ed25519:1": "XSl0kuyvrXNj6A+7/tkrB9sxSbRi08Of5uRhxOqZtEQ"
},
"meta": {
"unsigned": {
"retrieved_ts_ms": 922834800000
},
"signatures": {
@ -41,7 +41,7 @@ and additional signatures.
def sign_json(json_object, signing_key, signing_name):
signatures = json_object.pop("signatures", {})
meta = json_object.pop("meta", None)
unsigned = json_object.pop("unsigned", None)
signed = signing_key.sign(encode_canonical_json(json_object))
signature_base64 = encode_base64(signed.signature)
@ -50,8 +50,8 @@ and additional signatures.
signatures.setdefault(sigature_name, {})[key_id] = signature_base64
json_object["signatures"] = signatures
if meta is not None:
json_object["meta"] = meta
if unsigned is not None:
json_object["unsigned"] = unsigned
return json_object

View File

@ -0,0 +1,47 @@
from synapse.crypto.event_signing import *
from syutil.base64util import encode_base64
import argparse
import hashlib
import sys
import json
class dictobj(dict):
def __init__(self, *args, **kargs):
dict.__init__(self, *args, **kargs)
self.__dict__ = self
def get_dict(self):
return dict(self)
def get_full_dict(self):
return dict(self)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'),
default=sys.stdin)
args = parser.parse_args()
logging.basicConfig()
event_json = dictobj(json.load(args.input_json))
algorithms = {
"sha256": hashlib.sha256,
}
for alg_name in event_json.hashes:
if check_event_content_hash(event_json, algorithms[alg_name]):
print "PASS content hash %s" % (alg_name,)
else:
print "FAIL content hash %s" % (alg_name,)
for algorithm in algorithms.values():
name, h_bytes = compute_event_reference_hash(event_json, algorithm)
print "Reference hash %s: %s" % (name, encode_base64(h_bytes))
if __name__=="__main__":
main()

View File

@ -0,0 +1,73 @@
from syutil.crypto.jsonsign import verify_signed_json
from syutil.crypto.signing_key import (
decode_verify_key_bytes, write_signing_keys
)
from syutil.base64util import decode_base64
import urllib2
import json
import sys
import dns.resolver
import pprint
import argparse
import logging
def get_targets(server_name):
if ":" in server_name:
target, port = server_name.split(":")
yield (target, int(port))
return
try:
answers = dns.resolver.query("_matrix._tcp." + server_name, "SRV")
for srv in answers:
yield (srv.target, srv.port)
except dns.resolver.NXDOMAIN:
yield (server_name, 8480)
def get_server_keys(server_name, target, port):
url = "https://%s:%i/_matrix/key/v1" % (target, port)
keys = json.load(urllib2.urlopen(url))
verify_keys = {}
for key_id, key_base64 in keys["verify_keys"].items():
verify_key = decode_verify_key_bytes(key_id, decode_base64(key_base64))
verify_signed_json(keys, server_name, verify_key)
verify_keys[key_id] = verify_key
return verify_keys
def main():
parser = argparse.ArgumentParser()
parser.add_argument("signature_name")
parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'),
default=sys.stdin)
args = parser.parse_args()
logging.basicConfig()
server_name = args.signature_name
keys = {}
for target, port in get_targets(server_name):
try:
keys = get_server_keys(server_name, target, port)
print "Using keys from https://%s:%s/_matrix/key/v1" % (target, port)
write_signing_keys(sys.stdout, keys.values())
break
except:
logging.exception("Error talking to %s:%s", target, port)
json_to_check = json.load(args.input_json)
print "Checking JSON:"
for key_id in json_to_check["signatures"][args.signature_name]:
try:
key = keys[key_id]
verify_signed_json(json_to_check, args.signature_name, key)
print "PASS %s" % (key_id,)
except:
logging.exception("Check for key %s failed" % (key_id,))
print "FAIL %s" % (key_id,)
if __name__ == '__main__':
main()

69
scripts/hash_history.py Normal file
View File

@ -0,0 +1,69 @@
from synapse.storage.pdu import PduStore
from synapse.storage.signatures import SignatureStore
from synapse.storage._base import SQLBaseStore
from synapse.federation.units import Pdu
from synapse.crypto.event_signing import (
add_event_pdu_content_hash, compute_pdu_event_reference_hash
)
from synapse.api.events.utils import prune_pdu
from syutil.base64util import encode_base64, decode_base64
from syutil.jsonutil import encode_canonical_json
import sqlite3
import sys
class Store(object):
_get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"]
_get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"]
_get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"]
_get_pdu_origin_signatures_txn = SignatureStore.__dict__["_get_pdu_origin_signatures_txn"]
_store_pdu_content_hash_txn = SignatureStore.__dict__["_store_pdu_content_hash_txn"]
_store_pdu_reference_hash_txn = SignatureStore.__dict__["_store_pdu_reference_hash_txn"]
_store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"]
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
store = Store()
def select_pdus(cursor):
cursor.execute(
"SELECT pdu_id, origin FROM pdus ORDER BY depth ASC"
)
ids = cursor.fetchall()
pdu_tuples = store._get_pdu_tuples(cursor, ids)
pdus = [Pdu.from_pdu_tuple(p) for p in pdu_tuples]
reference_hashes = {}
for pdu in pdus:
try:
if pdu.prev_pdus:
print "PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus
for pdu_id, origin, hashes in pdu.prev_pdus:
ref_alg, ref_hsh = reference_hashes[(pdu_id, origin)]
hashes[ref_alg] = encode_base64(ref_hsh)
store._store_prev_pdu_hash_txn(cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh)
print "SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus
pdu = add_event_pdu_content_hash(pdu)
ref_alg, ref_hsh = compute_pdu_event_reference_hash(pdu)
reference_hashes[(pdu.pdu_id, pdu.origin)] = (ref_alg, ref_hsh)
store._store_pdu_reference_hash_txn(cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh)
for alg, hsh_base64 in pdu.hashes.items():
print alg, hsh_base64
store._store_pdu_content_hash_txn(cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64))
except:
print "FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus
def main():
conn = sqlite3.connect(sys.argv[1])
cursor = conn.cursor()
select_pdus(cursor)
conn.commit()
if __name__=='__main__':
main()

View File

@ -21,8 +21,10 @@ from synapse.api.constants import Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.api.events.room import (
RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent,
RoomJoinRulesEvent, RoomCreateEvent,
)
from synapse.util.logutils import log_function
from syutil.base64util import encode_base64
import logging
@ -35,8 +37,7 @@ class Auth(object):
self.hs = hs
self.store = hs.get_datastore()
@defer.inlineCallbacks
def check(self, event, snapshot, raises=False):
def check(self, event, raises=False):
""" Checks if this event is correctly authed.
Returns:
@ -47,43 +48,51 @@ class Auth(object):
"""
try:
if hasattr(event, "room_id"):
is_state = hasattr(event, "state_key")
if event.old_state_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
logger.warn("Trusting event: %s", event.event_id)
return True
if hasattr(event, "outlier") and event.outlier is True:
# TODO (erikj): Auth for outliers is done differently.
return True
if event.type == RoomCreateEvent.TYPE:
# FIXME
return True
if event.type == RoomMemberEvent.TYPE:
yield self._can_replace_state(event)
allowed = yield self.is_membership_change_allowed(event)
defer.returnValue(allowed)
return
allowed = self.is_membership_change_allowed(event)
if allowed:
logger.debug("Allowing! %s", event)
else:
logger.debug("Denying! %s", event)
return allowed
self._check_joined_room(
member=snapshot.membership_state,
user_id=snapshot.user_id,
room_id=snapshot.room_id,
)
if is_state:
# TODO (erikj): This really only should be called for *new*
# state
yield self._can_add_state(event)
yield self._can_replace_state(event)
else:
yield self._can_send_event(event)
self.check_event_sender_in_room(event)
self._can_send_event(event)
if event.type == RoomPowerLevelsEvent.TYPE:
yield self._check_power_levels(event)
self._check_power_levels(event)
if event.type == RoomRedactionEvent.TYPE:
yield self._check_redaction(event)
self._check_redaction(event)
defer.returnValue(True)
logger.debug("Allowing! %s", event)
return True
else:
raise AuthError(500, "Unknown event: %s" % event)
except AuthError as e:
logger.info("Event auth check failed on event %s with msg: %s",
event, e.msg)
logger.info(
"Event auth check failed on event %s with msg: %s",
event, e.msg
)
logger.info("Denying! %s", event)
if raises:
raise e
defer.returnValue(False)
return False
@defer.inlineCallbacks
def check_joined_room(self, room_id, user_id):
@ -98,45 +107,80 @@ class Auth(object):
pass
defer.returnValue(None)
@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
joined_hosts = yield self.store.get_joined_hosts_for_room(room_id)
defer.returnValue(host in joined_hosts)
def check_event_sender_in_room(self, event):
key = (RoomMemberEvent.TYPE, event.user_id, )
member_event = event.state_events.get(key)
return self._check_joined_room(
member_event,
event.user_id,
event.room_id
)
def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s (%s)" % (
user_id, room_id, repr(member)
))
@defer.inlineCallbacks
@log_function
def is_membership_change_allowed(self, event):
target_user_id = event.state_key
# does this room even exist
room = yield self.store.get_room(event.room_id)
if not room:
raise AuthError(403, "Room does not exist")
# get info about the caller
try:
caller = yield self.store.get_room_member(
user_id=event.user_id,
room_id=event.room_id)
except:
caller = None
caller_in_room = caller and caller.membership == "join"
key = (RoomMemberEvent.TYPE, event.user_id, )
caller = event.old_state_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target
try:
target = yield self.store.get_room_member(
user_id=target_user_id,
room_id=event.room_id)
except:
target = None
target_in_room = target and target.membership == "join"
key = (RoomMemberEvent.TYPE, target_user_id, )
target = event.old_state_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
membership = event.content["membership"]
join_rule = yield self.store.get_room_join_rule(event.room_id)
if not join_rule:
key = (RoomJoinRulesEvent.TYPE, "", )
join_rule_event = event.old_state_events.get(key)
if join_rule_event:
join_rule = join_rule_event.content.get(
"join_rule", JoinRules.INVITE
)
else:
join_rule = JoinRules.INVITE
user_level = self._get_power_level_from_event_state(
event,
event.user_id,
)
ban_level, kick_level, redact_level = (
self._get_ops_level_from_event_state(
event
)
)
logger.debug(
"is_membership_change_allowed: %s",
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
"target_in_room": target_in_room,
"membership": membership,
"join_rule": join_rule,
"target_user_id": target_user_id,
"event.user_id": event.user_id,
}
)
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
@ -153,13 +197,10 @@ class Auth(object):
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
elif join_rule == JoinRules.PUBLIC or room.is_public:
elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
if (
not caller or caller.membership not in
[Membership.INVITE, Membership.JOIN]
):
if not caller_in_room and not caller_invited:
raise AuthError(403, "You are not invited to this room.")
else:
# TODO (erikj): may_join list
@ -171,29 +212,16 @@ class Auth(object):
if not caller_in_room: # trying to leave a room you aren't joined
raise AuthError(403, "You are not in room %s." % event.room_id)
elif target_user_id != event.user_id:
user_level = yield self.store.get_power_level(
event.room_id,
event.user_id,
)
_, kick_level, _ = yield self.store.get_ops_levels(event.room_id)
if kick_level:
kick_level = int(kick_level)
else:
kick_level = 50
kick_level = 50 # FIXME (erikj): What should we do here?
if user_level < kick_level:
raise AuthError(
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
user_level = yield self.store.get_power_level(
event.room_id,
event.user_id,
)
ban_level, _, _ = yield self.store.get_ops_levels(event.room_id)
if ban_level:
ban_level = int(ban_level)
else:
@ -204,7 +232,30 @@ class Auth(object):
else:
raise AuthError(500, "Unknown membership %s" % membership)
defer.returnValue(True)
return True
def _get_power_level_from_event_state(self, event, user_id):
key = (RoomPowerLevelsEvent.TYPE, "", )
power_level_event = event.old_state_events.get(key)
level = None
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
if not level:
level = power_level_event.content.get("users_default", 0)
return level
def _get_ops_level_from_event_state(self, event):
key = (RoomPowerLevelsEvent.TYPE, "", )
power_level_event = event.old_state_events.get(key)
if power_level_event:
return (
power_level_event.content.get("ban", 50),
power_level_event.content.get("kick", 50),
power_level_event.content.get("redact", 50),
)
return None, None, None,
@defer.inlineCallbacks
def get_user_by_req(self, request):
@ -229,7 +280,7 @@ class Auth(object):
default=[""]
)[0]
if user and access_token and ip_addr:
self.store.insert_client_ip(
yield self.store.insert_client_ip(
user=user,
access_token=access_token,
device_id=user_info["device_id"],
@ -273,17 +324,81 @@ class Auth(object):
return self.store.is_server_admin(user)
@defer.inlineCallbacks
def add_auth_events(self, event):
if event.type == RoomCreateEvent.TYPE:
event.auth_events = []
return
auth_events = []
key = (RoomPowerLevelsEvent.TYPE, "", )
power_level_event = event.old_state_events.get(key)
if power_level_event:
auth_events.append(power_level_event.event_id)
key = (RoomJoinRulesEvent.TYPE, "", )
join_rule_event = event.old_state_events.get(key)
key = (RoomMemberEvent.TYPE, event.user_id, )
member_event = event.old_state_events.get(key)
if join_rule_event:
join_rule = join_rule_event.content.get("join_rule")
is_public = join_rule == JoinRules.PUBLIC if join_rule else False
else:
is_public = False
if event.type == RoomMemberEvent.TYPE:
e_type = event.content["membership"]
if e_type in [Membership.JOIN, Membership.INVITE]:
if join_rule_event:
auth_events.append(join_rule_event.event_id)
if member_event and not is_public:
auth_events.append(member_event.event_id)
elif member_event:
if member_event.content["membership"] == Membership.JOIN:
auth_events.append(member_event.event_id)
hashes = yield self.store.get_event_reference_hashes(
auth_events
)
hashes = [
{
k: encode_base64(v) for k, v in h.items()
if k == "sha256"
}
for h in hashes
]
event.auth_events = zip(auth_events, hashes)
@log_function
def _can_send_event(self, event):
send_level = yield self.store.get_send_event_level(event.room_id)
key = (RoomPowerLevelsEvent.TYPE, "", )
send_level_event = event.old_state_events.get(key)
send_level = None
if send_level_event:
send_level = send_level_event.content.get("events", {}).get(
event.type
)
if not send_level:
if hasattr(event, "state_key"):
send_level = send_level_event.content.get(
"state_default", 50
)
else:
send_level = send_level_event.content.get(
"events_default", 0
)
if send_level:
send_level = int(send_level)
else:
send_level = 0
user_level = yield self.store.get_power_level(
event.room_id,
user_level = self._get_power_level_from_event_state(
event,
event.user_id,
)
@ -294,84 +409,22 @@ class Auth(object):
if user_level < send_level:
raise AuthError(
403, "You don't have permission to post to the room"
403,
"You don't have permission to post that to the room. " +
"user_level (%d) < send_level (%d)" % (user_level, send_level)
)
defer.returnValue(True)
return True
@defer.inlineCallbacks
def _can_add_state(self, event):
add_level = yield self.store.get_add_state_level(event.room_id)
if not add_level:
defer.returnValue(True)
add_level = int(add_level)
user_level = yield self.store.get_power_level(
event.room_id,
event.user_id,
)
user_level = int(user_level)
if user_level < add_level:
raise AuthError(
403, "You don't have permission to add state to the room"
)
defer.returnValue(True)
@defer.inlineCallbacks
def _can_replace_state(self, event):
current_state = yield self.store.get_current_state(
event.room_id,
event.type,
event.state_key,
)
if current_state:
current_state = current_state[0]
user_level = yield self.store.get_power_level(
event.room_id,
event.user_id,
)
if user_level:
user_level = int(user_level)
else:
user_level = 0
logger.debug(
"Checking power level for %s, %s", event.user_id, user_level
)
if current_state and hasattr(current_state, "required_power_level"):
req = current_state.required_power_level
logger.debug("Checked power level for %s, %s", event.user_id, req)
if user_level < req:
raise AuthError(
403,
"You don't have permission to change that state"
)
@defer.inlineCallbacks
def _check_redaction(self, event):
user_level = yield self.store.get_power_level(
event.room_id,
user_level = self._get_power_level_from_event_state(
event,
event.user_id,
)
if user_level:
user_level = int(user_level)
else:
user_level = 0
_, _, redact_level = yield self.store.get_ops_levels(event.room_id)
if not redact_level:
redact_level = 50
_, _, redact_level = self._get_ops_level_from_event_state(
event
)
if user_level < redact_level:
raise AuthError(
@ -379,16 +432,10 @@ class Auth(object):
"You don't have permission to redact events"
)
@defer.inlineCallbacks
def _check_power_levels(self, event):
for k, v in event.content.items():
if k == "default":
continue
# FIXME (erikj): We don't want hsob_Ts in content.
if k == "hsob_ts":
continue
user_list = event.content.get("users", {})
# Validate users
for k, v in user_list.items():
try:
self.hs.parse_userid(k)
except:
@ -399,80 +446,68 @@ class Auth(object):
except:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
current_state = yield self.store.get_current_state(
event.room_id,
event.type,
event.state_key,
)
key = (event.type, event.state_key, )
current_state = event.old_state_events.get(key)
if not current_state:
return
else:
current_state = current_state[0]
user_level = yield self.store.get_power_level(
event.room_id,
user_level = self._get_power_level_from_event_state(
event,
event.user_id,
)
if user_level:
user_level = int(user_level)
else:
user_level = 0
# Check other levels:
levels_to_check = [
("users_default", []),
("events_default", []),
("ban", []),
("redact", []),
("kick", []),
]
old_list = current_state.content
old_list = current_state.content.get("users")
for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append(
(user, ["users"])
)
# FIXME (erikj)
old_people = {k: v for k, v in old_list.items() if k.startswith("@")}
new_people = {
k: v for k, v in event.content.items()
if k.startswith("@")
}
old_list = current_state.content.get("events")
new_list = event.content.get("events")
for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append(
(ev_id, ["events"])
)
removed = set(old_people.keys()) - set(new_people.keys())
added = set(new_people.keys()) - set(old_people.keys())
same = set(old_people.keys()) & set(new_people.keys())
old_state = current_state.content
new_state = event.content
for r in removed:
if int(old_list[r]) > user_level:
raise AuthError(
403,
"You don't have permission to remove user: %s" % (r, )
)
for level_to_check, dir in levels_to_check:
old_loc = old_state
for d in dir:
old_loc = old_loc.get(d, {})
for n in added:
if int(event.content[n]) > user_level:
new_loc = new_state
for d in dir:
new_loc = new_loc.get(d, {})
if level_to_check in old_loc:
old_level = int(old_loc[level_to_check])
else:
old_level = None
if level_to_check in new_loc:
new_level = int(new_loc[level_to_check])
else:
new_level = None
if new_level is not None and old_level is not None:
if new_level == old_level:
continue
if old_level > user_level or new_level > user_level:
raise AuthError(
403,
"You don't have permission to add ops level greater "
"than your own"
)
for s in same:
if int(event.content[s]) != int(old_list[s]):
if int(event.content[s]) > user_level:
raise AuthError(
403,
"You don't have permission to add ops level greater "
"than your own"
)
if "default" in old_list:
old_default = int(old_list["default"])
if old_default > user_level:
raise AuthError(
403,
"You don't have permission to add ops level greater than "
"your own"
)
if "default" in event.content:
new_default = int(event.content["default"])
if new_default > user_level:
raise AuthError(
403,
"You don't have permission to add ops level greater "
"than your own"
)

View File

@ -158,3 +158,37 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
for key, value in kwargs.iteritems():
err[key] = value
return err
class FederationError(RuntimeError):
""" This class is used to inform remote home servers about erroneous
PDUs they sent us.
FATAL: The remote server could not interpret the source event.
(e.g., it was missing a required field)
ERROR: The remote server interpreted the event, but it failed some other
check (e.g. auth)
WARN: The remote server accepted the event, but believes some part of it
is wrong (e.g., it referred to an invalid event)
"""
def __init__(self, level, code, reason, affected, source=None):
if level not in ["FATAL", "ERROR", "WARN"]:
raise ValueError("Level is not valid: %s" % (level,))
self.level = level
self.code = code
self.reason = reason
self.affected = affected
self.source = source
msg = "%s %s: %s" % (level, code, reason,)
super(FederationError, self).__init__(msg)
def get_dict(self):
return {
"level": self.level,
"code": self.code,
"reason": self.reason,
"affected": self.affected,
"source": self.source if self.source else self.affected,
}

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.errors import SynapseError, Codes
from synapse.util.jsonobject import JsonEncodedObject
@ -56,22 +55,26 @@ class SynapseEvent(JsonEncodedObject):
"user_id", # sender/initiator
"content", # HTTP body, JSON
"state_key",
"required_power_level",
"age_ts",
"prev_content",
"prev_state",
"replaces_state",
"redacted_because",
"origin_server_ts",
]
internal_keys = [
"is_state",
"prev_events",
"depth",
"destinations",
"origin",
"outlier",
"power_level",
"redacted",
"prev_events",
"hashes",
"signatures",
"prev_state",
"auth_events",
"state_hash",
]
required_keys = [
@ -82,8 +85,8 @@ class SynapseEvent(JsonEncodedObject):
def __init__(self, raises=True, **kwargs):
super(SynapseEvent, self).__init__(**kwargs)
if "content" in kwargs:
self.check_json(self.content, raises=raises)
# if "content" in kwargs:
# self.check_json(self.content, raises=raises)
def get_content_template(self):
""" Retrieve the JSON template for this event as a dict.
@ -114,66 +117,6 @@ class SynapseEvent(JsonEncodedObject):
"""
raise NotImplementedError("get_content_template not implemented.")
def check_json(self, content, raises=True):
"""Checks the given JSON content abides by the rules of the template.
Args:
content : A JSON object to check.
raises: True to raise a SynapseError if the check fails.
Returns:
True if the content passes the template. Returns False if the check
fails and raises=False.
Raises:
SynapseError if the check fails and raises=True.
"""
# recursively call to inspect each layer
err_msg = self._check_json(content, self.get_content_template())
if err_msg:
if raises:
raise SynapseError(400, err_msg, Codes.BAD_JSON)
else:
return False
else:
return True
def _check_json(self, content, template):
"""Check content and template matches.
If the template is a dict, each key in the dict will be validated with
the content, else it will just compare the types of content and
template. This basic type check is required because this function will
be recursively called and could be called with just strs or ints.
Args:
content: The content to validate.
template: The validation template.
Returns:
str: An error message if the validation fails, else None.
"""
if type(content) != type(template):
return "Mismatched types: %s" % template
if type(template) == dict:
for key in template:
if key not in content:
return "Missing %s key" % key
if type(content[key]) != type(template[key]):
return "Key %s is of the wrong type (got %s, want %s)" % (
key, type(content[key]), type(template[key]))
if type(content[key]) == dict:
# we must go deeper
msg = self._check_json(content[key], template[key])
if msg:
return msg
elif type(content[key]) == list:
# make sure each item type in content matches the template
for entry in content[key]:
msg = self._check_json(entry, template[key][0])
if msg:
return msg
class SynapseStateEvent(SynapseEvent):

View File

@ -16,11 +16,13 @@
from synapse.api.events.room import (
RoomTopicEvent, MessageEvent, RoomMemberEvent, FeedbackEvent,
InviteJoinEvent, RoomConfigEvent, RoomNameEvent, GenericEvent,
RoomPowerLevelsEvent, RoomJoinRulesEvent, RoomOpsPowerLevelsEvent,
RoomCreateEvent, RoomAddStateLevelEvent, RoomSendEventLevelEvent,
RoomPowerLevelsEvent, RoomJoinRulesEvent,
RoomCreateEvent,
RoomRedactionEvent,
)
from synapse.types import EventID
from synapse.util.stringutils import random_string
@ -37,9 +39,6 @@ class EventFactory(object):
RoomPowerLevelsEvent,
RoomJoinRulesEvent,
RoomCreateEvent,
RoomAddStateLevelEvent,
RoomSendEventLevelEvent,
RoomOpsPowerLevelsEvent,
RoomRedactionEvent,
]
@ -51,12 +50,26 @@ class EventFactory(object):
self.clock = hs.get_clock()
self.hs = hs
self.event_id_count = 0
def create_event_id(self):
i = str(self.event_id_count)
self.event_id_count += 1
local_part = str(int(self.clock.time())) + i + random_string(5)
e_id = EventID.create_local(local_part, self.hs)
return e_id.to_string()
def create_event(self, etype=None, **kwargs):
kwargs["type"] = etype
if "event_id" not in kwargs:
kwargs["event_id"] = "%s@%s" % (
random_string(10), self.hs.hostname
)
kwargs["event_id"] = self.create_event_id()
kwargs["origin"] = self.hs.hostname
else:
ev_id = self.hs.parse_eventid(kwargs["event_id"])
kwargs["origin"] = ev_id.domain
if "origin_server_ts" not in kwargs:
kwargs["origin_server_ts"] = int(self.clock.time_msec())

View File

@ -154,27 +154,6 @@ class RoomPowerLevelsEvent(SynapseStateEvent):
return {}
class RoomAddStateLevelEvent(SynapseStateEvent):
TYPE = "m.room.add_state_level"
def get_content_template(self):
return {}
class RoomSendEventLevelEvent(SynapseStateEvent):
TYPE = "m.room.send_event_level"
def get_content_template(self):
return {}
class RoomOpsPowerLevelsEvent(SynapseStateEvent):
TYPE = "m.room.ops_levels"
def get_content_template(self):
return {}
class RoomAliasesEvent(SynapseStateEvent):
TYPE = "m.room.aliases"

View File

@ -15,21 +15,34 @@
from .room import (
RoomMemberEvent, RoomJoinRulesEvent, RoomPowerLevelsEvent,
RoomAddStateLevelEvent, RoomSendEventLevelEvent, RoomOpsPowerLevelsEvent,
RoomAliasesEvent, RoomCreateEvent,
)
def prune_event(event):
""" Prunes the given event of all keys we don't know about or think could
potentially be dodgy.
""" Returns a pruned version of the given event, which removes all keys we
don't know about or think could potentially be dodgy.
This is used when we "redact" an event. We want to remove all fields that
the user has specified, but we do want to keep necessary information like
type, state_key etc.
"""
event_type = event.type
# Remove all extraneous fields.
event.unrecognized_keys = {}
allowed_keys = [
"event_id",
"user_id",
"room_id",
"hashes",
"signatures",
"content",
"type",
"state_key",
"depth",
"prev_events",
"prev_state",
"auth_events",
]
new_content = {}
@ -38,27 +51,33 @@ def prune_event(event):
if field in event.content:
new_content[field] = event.content[field]
if event.type == RoomMemberEvent.TYPE:
if event_type == RoomMemberEvent.TYPE:
add_fields("membership")
elif event.type == RoomCreateEvent.TYPE:
elif event_type == RoomCreateEvent.TYPE:
add_fields("creator")
elif event.type == RoomJoinRulesEvent.TYPE:
elif event_type == RoomJoinRulesEvent.TYPE:
add_fields("join_rule")
elif event.type == RoomPowerLevelsEvent.TYPE:
# TODO: Actually check these are valid user_ids etc.
add_fields("default")
for k, v in event.content.items():
if k.startswith("@") and isinstance(v, (int, long)):
new_content[k] = v
elif event.type == RoomAddStateLevelEvent.TYPE:
add_fields("level")
elif event.type == RoomSendEventLevelEvent.TYPE:
add_fields("level")
elif event.type == RoomOpsPowerLevelsEvent.TYPE:
add_fields("kick_level", "ban_level", "redact_level")
elif event.type == RoomAliasesEvent.TYPE:
elif event_type == RoomPowerLevelsEvent.TYPE:
add_fields(
"users",
"users_default",
"events",
"events_default",
"events_default",
"state_default",
"ban",
"kick",
"redact",
)
elif event_type == RoomAliasesEvent.TYPE:
add_fields("aliases")
event.content = new_content
allowed_fields = {
k: v
for k, v in event.get_full_dict().items()
if k in allowed_keys
}
return event
allowed_fields["content"] = new_content
return type(event)(**allowed_fields)

View File

@ -0,0 +1,87 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.errors import SynapseError, Codes
class EventValidator(object):
def __init__(self, hs):
pass
def validate(self, event):
"""Checks the given JSON content abides by the rules of the template.
Args:
content : A JSON object to check.
raises: True to raise a SynapseError if the check fails.
Returns:
True if the content passes the template. Returns False if the check
fails and raises=False.
Raises:
SynapseError if the check fails and raises=True.
"""
# recursively call to inspect each layer
err_msg = self._check_json_template(
event.content,
event.get_content_template()
)
if err_msg:
raise SynapseError(400, err_msg, Codes.BAD_JSON)
else:
return True
def _check_json_template(self, content, template):
"""Check content and template matches.
If the template is a dict, each key in the dict will be validated with
the content, else it will just compare the types of content and
template. This basic type check is required because this function will
be recursively called and could be called with just strs or ints.
Args:
content: The content to validate.
template: The validation template.
Returns:
str: An error message if the validation fails, else None.
"""
if type(content) != type(template):
return "Mismatched types: %s" % template
if type(template) == dict:
for key in template:
if key not in content:
return "Missing %s key" % key
if type(content[key]) != type(template[key]):
return "Key %s is of the wrong type (got %s, want %s)" % (
key, type(content[key]), type(template[key]))
if type(content[key]) == dict:
# we must go deeper
msg = self._check_json_template(
content[key],
template[key]
)
if msg:
return msg
elif type(content[key]) == list:
# make sure each item type in content matches the template
for entry in content[key]:
msg = self._check_json_template(
entry,
template[key][0]
)
if msg:
return msg

View File

@ -236,7 +236,10 @@ def setup():
f.namespace['hs'] = hs
reactor.listenTCP(config.manhole, f, interface='127.0.0.1')
hs.start_listening(config.bind_port, config.unsecure_port)
bind_port = config.bind_port
if config.no_tls:
bind_port = None
hs.start_listening(bind_port, config.unsecure_port)
if config.daemonize:
print config.pid_file

View File

@ -30,6 +30,7 @@ class ServerConfig(Config):
self.pid_file = self.abspath(args.pid_file)
self.webclient = True
self.manhole = args.manhole
self.no_tls = args.no_tls
if not args.content_addr:
host = args.server_name
@ -67,6 +68,8 @@ class ServerConfig(Config):
server_group.add_argument("--content-addr", default=None,
help="The host and scheme to use for the "
"content repository")
server_group.add_argument("--no-tls", action='store_true',
help="Don't bind to the https port.")
def read_signing_key(self, signing_key_path):
signing_keys = self.read_file(signing_key_path, "signing_key")

View File

@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from syutil.base64util import encode_base64, decode_base64
from syutil.crypto.jsonsign import sign_json
import hashlib
import logging
logger = logging.getLogger(__name__)
def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents"""
computed_hash = _compute_content_hash(event, hash_algorithm)
if computed_hash.name not in event.hashes:
raise Exception("Algorithm %s not in hashes %s" % (
computed_hash.name, list(event.hashes)
))
message_hash_base64 = event.hashes[computed_hash.name]
try:
message_hash_bytes = decode_base64(message_hash_base64)
except:
raise Exception("Invalid base64: %s" % (message_hash_base64,))
return message_hash_bytes == computed_hash.digest()
def _compute_content_hash(event, hash_algorithm):
event_json = event.get_full_dict()
# TODO: We need to sign the JSON that is going out via fedaration.
event_json.pop("age_ts", None)
event_json.pop("unsigned", None)
event_json.pop("signatures", None)
event_json.pop("hashes", None)
event_json_bytes = encode_canonical_json(event_json)
return hash_algorithm(event_json_bytes)
def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
tmp_event = prune_event(event)
event_json = tmp_event.get_dict()
event_json.pop("signatures", None)
event_json.pop("age_ts", None)
event_json.pop("unsigned", None)
event_json_bytes = encode_canonical_json(event_json)
hashed = hash_algorithm(event_json_bytes)
return (hashed.name, hashed.digest())
def compute_event_signature(event, signature_name, signing_key):
tmp_event = prune_event(event)
redact_json = tmp_event.get_full_dict()
redact_json.pop("signatures", None)
redact_json.pop("age_ts", None)
redact_json.pop("unsigned", None)
logger.debug("Signing event: %s", redact_json)
redact_json = sign_json(redact_json, signature_name, signing_key)
return redact_json["signatures"]
def add_hashes_and_signatures(event, signature_name, signing_key,
hash_algorithm=hashlib.sha256):
if hasattr(event, "old_state_events"):
state_json_bytes = encode_canonical_json(
[e.event_id for e in event.old_state_events.values()]
)
hashed = hash_algorithm(state_json_bytes)
event.state_hash = {
hashed.name: encode_base64(hashed.digest())
}
hashed = _compute_content_hash(event, hash_algorithm=hash_algorithm)
if not hasattr(event, "hashes"):
event.hashes = {}
event.hashes[hashed.name] = encode_base64(hashed.digest())
event.signatures = compute_event_signature(
event,
signature_name=signature_name,
signing_key=signing_key,
)

View File

@ -18,50 +18,25 @@ from .units import Pdu
import copy
def decode_event_id(event_id, server_name):
parts = event_id.split("@")
if len(parts) < 2:
return (event_id, server_name)
else:
return (parts[0], "".join(parts[1:]))
def encode_event_id(pdu_id, origin):
return "%s@%s" % (pdu_id, origin)
class PduCodec(object):
def __init__(self, hs):
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
self.event_factory = hs.get_event_factory()
self.clock = hs.get_clock()
self.hs = hs
def event_from_pdu(self, pdu):
kwargs = {}
kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
kwargs["room_id"] = pdu.context
kwargs["etype"] = pdu.pdu_type
kwargs["prev_events"] = [
encode_event_id(p[0], p[1]) for p in pdu.prev_pdus
]
if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
kwargs["prev_state"] = encode_event_id(
pdu.prev_state_id, pdu.prev_state_origin
)
kwargs["etype"] = pdu.type
kwargs.update({
k: v
for k, v in pdu.get_full_dict().items()
if k not in [
"pdu_id",
"context",
"pdu_type",
"prev_pdus",
"prev_state_id",
"prev_state_origin",
"type",
]
})
@ -70,33 +45,10 @@ class PduCodec(object):
def pdu_from_event(self, event):
d = event.get_full_dict()
d["pdu_id"], d["origin"] = decode_event_id(
event.event_id, self.server_name
)
d["context"] = event.room_id
d["pdu_type"] = event.type
if hasattr(event, "prev_events"):
d["prev_pdus"] = [
decode_event_id(e, self.server_name)
for e in event.prev_events
]
if hasattr(event, "prev_state"):
d["prev_state_id"], d["prev_state_origin"] = (
decode_event_id(event.prev_state, self.server_name)
)
if hasattr(event, "state_key"):
d["is_state"] = True
kwargs = copy.deepcopy(event.unrecognized_keys)
kwargs.update({
k: v for k, v in d.items()
if k not in ["event_id", "room_id", "type", "prev_events"]
})
if "origin_server_ts" not in kwargs:
kwargs["origin_server_ts"] = int(self.clock.time_msec())
return Pdu(**kwargs)
pdu = Pdu(**kwargs)
return pdu

View File

@ -21,8 +21,6 @@ These actions are mostly only used by the :py:mod:`.replication` module.
from twisted.internet import defer
from .units import Pdu
from synapse.util.logutils import log_function
import json
@ -32,76 +30,6 @@ import logging
logger = logging.getLogger(__name__)
class PduActions(object):
""" Defines persistence actions that relate to handling PDUs.
"""
def __init__(self, datastore):
self.store = datastore
@log_function
def mark_as_processed(self, pdu):
""" Persist the fact that we have fully processed the given `Pdu`
Returns:
Deferred
"""
return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin)
@defer.inlineCallbacks
@log_function
def after_transaction(self, transaction_id, destination, origin):
""" Returns all `Pdu`s that we sent to the given remote home server
after a given transaction id.
Returns:
Deferred: Results in a list of `Pdu`s
"""
results = yield self.store.get_pdus_after_transaction(
transaction_id,
destination
)
defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
@defer.inlineCallbacks
@log_function
def get_all_pdus_from_context(self, context):
results = yield self.store.get_all_pdus_from_context(context)
defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
@defer.inlineCallbacks
@log_function
def backfill(self, context, pdu_list, limit):
""" For a given list of PDU id and origins return the proceeding
`limit` `Pdu`s in the given `context`.
Returns:
Deferred: Results in a list of `Pdu`s.
"""
results = yield self.store.get_backfill(
context, pdu_list, limit
)
defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
@log_function
def is_new(self, pdu):
""" When we receive a `Pdu` from a remote home server, we want to
figure out whether it is `new`, i.e. it is not some historic PDU that
we haven't seen simply because we haven't backfilled back that far.
Returns:
Deferred: Results in a `bool`
"""
return self.store.is_pdu_new(
pdu_id=pdu.pdu_id,
origin=pdu.origin,
context=pdu.context,
depth=pdu.depth
)
class TransactionActions(object):
""" Defines persistence actions that relate to handling Transactions.
"""
@ -158,7 +86,6 @@ class TransactionActions(object):
transaction.transaction_id,
transaction.destination,
transaction.origin_server_ts,
[(p["pdu_id"], p["origin"]) for p in transaction.pdus]
)
@log_function

View File

@ -21,7 +21,7 @@ from twisted.internet import defer
from .units import Transaction, Pdu, Edu
from .persistence import PduActions, TransactionActions
from .persistence import TransactionActions
from synapse.util.logutils import log_function
@ -57,7 +57,7 @@ class ReplicationLayer(object):
self.transport_layer.register_request_handler(self)
self.store = hs.get_datastore()
self.pdu_actions = PduActions(self.store)
# self.pdu_actions = PduActions(self.store)
self.transaction_actions = TransactionActions(self.store)
self._transaction_queue = _TransactionQueue(
@ -81,7 +81,7 @@ class ReplicationLayer(object):
def register_edu_handler(self, edu_type, handler):
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type))
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
@ -102,24 +102,17 @@ class ReplicationLayer(object):
object to encode as JSON.
"""
if query_type in self.query_handlers:
raise KeyError("Already have a Query handler for %s" % (query_type))
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@defer.inlineCallbacks
@log_function
def send_pdu(self, pdu):
"""Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others.
This will fill out various attributes on the PDU object, e.g. the
`prev_pdus` key.
*Note:* The home server should always call `send_pdu` even if it knows
that it does not need to be replicated to other home servers. This is
in case e.g. someone else joins via a remote home server and then
backfills.
TODO: Figure out when we should actually resolve the deferred.
Args:
@ -132,18 +125,15 @@ class ReplicationLayer(object):
order = self._order
self._order += 1
logger.debug("[%s] Persisting PDU", pdu.pdu_id)
# Save *before* trying to send
yield self.store.persist_event(pdu=pdu)
logger.debug("[%s] Persisted PDU", pdu.pdu_id)
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, order)
logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id)
logger.debug(
"[%s] transaction_layer.enqueue_pdu... done",
pdu.event_id
)
@log_function
def send_edu(self, destination, edu_type, content):
@ -158,6 +148,11 @@ class ReplicationLayer(object):
self._transaction_queue.enqueue_edu(edu)
return defer.succeed(None)
@log_function
def send_failure(self, failure, destination):
self._transaction_queue.enqueue_failure(failure, destination)
return defer.succeed(None)
@log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=True):
@ -181,7 +176,7 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
def backfill(self, dest, context, limit):
def backfill(self, dest, context, limit, extremities):
"""Requests some more historic PDUs for the given context from the
given destination server.
@ -189,12 +184,12 @@ class ReplicationLayer(object):
dest (str): The remote home server to ask.
context (str): The context to backfill.
limit (int): The maximum number of PDUs to return.
extremities (list): List of PDU id and origins of the first pdus
we have seen from the context
Returns:
Deferred: Results in the received PDUs.
"""
extremities = yield self.store.get_oldest_pdus_in_context(context)
logger.debug("backfill extrem=%s", extremities)
# If there are no extremeties then we've (probably) reached the start.
@ -210,13 +205,13 @@ class ReplicationLayer(object):
pdus = [Pdu(outlier=False, **p) for p in transaction.pdus]
for pdu in pdus:
yield self._handle_new_pdu(pdu, backfilled=True)
yield self._handle_new_pdu(dest, pdu, backfilled=True)
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False):
def get_pdu(self, destination, event_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home
server.
@ -225,7 +220,7 @@ class ReplicationLayer(object):
Args:
destination (str): Which home server to query
pdu_origin (str): The home server that originally sent the pdu.
pdu_id (str)
event_id (str)
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
@ -234,8 +229,9 @@ class ReplicationLayer(object):
Deferred: Results in the requested PDU.
"""
transaction_data = yield self.transport_layer.get_pdu(
destination, pdu_origin, pdu_id)
transaction_data = yield self.transport_layer.get_event(
destination, event_id
)
transaction = Transaction(**transaction_data)
@ -244,13 +240,13 @@ class ReplicationLayer(object):
pdu = None
if pdu_list:
pdu = pdu_list[0]
yield self._handle_new_pdu(pdu)
yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
def get_state_for_context(self, destination, context):
def get_state_for_context(self, destination, context, event_id=None):
"""Requests all of the `current` state PDUs for a given context from
a remote home server.
@ -263,29 +259,25 @@ class ReplicationLayer(object):
"""
transaction_data = yield self.transport_layer.get_context_state(
destination, context)
destination,
context,
event_id=event_id,
)
transaction = Transaction(**transaction_data)
pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
for pdu in pdus:
yield self._handle_new_pdu(pdu)
yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def on_context_pdus_request(self, context):
pdus = yield self.pdu_actions.get_all_pdus_from_context(
context
def on_backfill_request(self, origin, context, versions, limit):
pdus = yield self.handler.on_backfill_request(
origin, context, versions, limit
)
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, context, versions, limit):
pdus = yield self.pdu_actions.backfill(context, versions, limit)
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@ -295,6 +287,10 @@ class ReplicationLayer(object):
transaction = Transaction(**transaction_data)
for p in transaction.pdus:
if "unsigned" in p:
unsigned = p["unsigned"]
if "age" in unsigned:
p["age"] = unsigned["age"]
if "age" in p:
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
del p["age"]
@ -315,11 +311,15 @@ class ReplicationLayer(object):
dl = []
for pdu in pdu_list:
dl.append(self._handle_new_pdu(pdu))
dl.append(self._handle_new_pdu(transaction.origin, pdu))
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu(transaction.origin, edu.edu_type, edu.content)
self.received_edu(
transaction.origin,
edu.edu_type,
edu.content
)
results = yield defer.DeferredList(dl)
@ -347,20 +347,22 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
def on_context_state_request(self, context):
results = yield self.store.get_current_state_for_context(
context
)
def on_context_state_request(self, origin, context, event_id):
if event_id:
pdus = yield self.handler.get_state_for_pdu(
origin,
context,
event_id,
)
else:
raise NotImplementedError("Specify an event")
logger.debug("Context returning %d results", len(results))
pdus = [Pdu.from_pdu_tuple(p) for p in results]
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
def on_pdu_request(self, pdu_origin, pdu_id):
pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin)
def on_pdu_request(self, origin, event_id):
pdu = yield self._get_persisted_pdu(origin, event_id)
if pdu:
defer.returnValue(
@ -372,20 +374,7 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
def on_pull_request(self, origin, versions):
transaction_id = max([int(v) for v in versions])
response = yield self.pdu_actions.after_transaction(
transaction_id,
origin,
self.server_name
)
if not response:
response = []
defer.returnValue(
(200, self._transaction_from_pdus(response).get_dict())
)
raise NotImplementedError("Pull transacions not implemented")
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
@ -393,82 +382,183 @@ class ReplicationLayer(object):
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
defer.returnValue((404, "No handler for Query type '%s'"
% (query_type)
))
defer.returnValue(
(404, "No handler for Query type '%s'" % (query_type, ))
)
@defer.inlineCallbacks
def on_make_join_request(self, context, user_id):
pdu = yield self.handler.on_make_join_request(context, user_id)
defer.returnValue({
"event": pdu.get_dict(),
})
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
pdu = Pdu(**content)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
defer.returnValue(
(
200,
{
"event": ret_pdu.get_dict(),
}
)
)
@defer.inlineCallbacks
def on_send_join_request(self, origin, content):
pdu = Pdu(**content)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
defer.returnValue((200, {
"state": [p.get_dict() for p in res_pdus["state"]],
"auth_chain": [p.get_dict() for p in res_pdus["auth_chain"]],
}))
@defer.inlineCallbacks
def on_event_auth(self, origin, context, event_id):
auth_pdus = yield self.handler.on_event_auth(event_id)
defer.returnValue(
(
200,
{
"auth_chain": [a.get_dict() for a in auth_pdus],
}
)
)
@defer.inlineCallbacks
def make_join(self, destination, context, user_id):
ret = yield self.transport_layer.make_join(
destination=destination,
context=context,
user_id=user_id,
)
pdu_dict = ret["event"]
logger.debug("Got response to make_join: %s", pdu_dict)
defer.returnValue(Pdu(**pdu_dict))
@defer.inlineCallbacks
def send_join(self, destination, pdu):
_, content = yield self.transport_layer.send_join(
destination,
pdu.room_id,
pdu.event_id,
pdu.get_dict(),
)
logger.debug("Got content: %s", content)
state = [Pdu(outlier=True, **p) for p in content.get("state", [])]
for pdu in state:
yield self._handle_new_pdu(destination, pdu)
auth_chain = [
Pdu(outlier=True, **p) for p in content.get("auth_chain", [])
]
for pdu in auth_chain:
yield self._handle_new_pdu(destination, pdu)
defer.returnValue(state)
@defer.inlineCallbacks
def send_invite(self, destination, context, event_id, pdu):
code, content = yield self.transport_layer.send_invite(
destination=destination,
context=context,
event_id=event_id,
content=pdu.get_dict(),
)
pdu_dict = content["event"]
logger.debug("Got response to send_invite: %s", pdu_dict)
defer.returnValue(Pdu(**pdu_dict))
@log_function
def _get_persisted_pdu(self, pdu_id, pdu_origin):
def _get_persisted_pdu(self, origin, event_id):
""" Get a PDU from the database with given origin and id.
Returns:
Deferred: Results in a `Pdu`.
"""
pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin)
defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple))
return self.handler.get_persisted_pdu(origin, event_id)
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
pdus = [p.get_dict() for p in pdu_list]
time_now = self._clock.time_msec()
for p in pdus:
if "age_ts" in pdus:
p["age"] = int(self.clock.time_msec()) - p["age_ts"]
if "age_ts" in p:
age = time_now - p["age_ts"]
p.setdefault("unsigned", {})["age"] = int(age)
del p["age_ts"]
return Transaction(
origin=self.server_name,
pdus=pdus,
origin_server_ts=int(self._clock.time_msec()),
origin_server_ts=int(time_now),
destination=None,
)
@defer.inlineCallbacks
@log_function
def _handle_new_pdu(self, pdu, backfilled=False):
def _handle_new_pdu(self, origin, pdu, backfilled=False):
# We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
existing = yield self._get_persisted_pdu(origin, pdu.event_id)
if existing and (not existing.outlier or pdu.outlier):
logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin)
logger.debug("Already seen pdu %s", pdu.event_id)
defer.returnValue({})
return
state = None
# Get missing pdus if necessary.
is_new = yield self.pdu_actions.is_new(pdu)
if is_new and not pdu.outlier:
if not pdu.outlier:
# We only backfill backwards to the min depth.
min_depth = yield self.store.get_min_depth_for_context(pdu.context)
min_depth = yield self.handler.get_min_depth_for_context(
pdu.room_id
)
if min_depth and pdu.depth > min_depth:
for pdu_id, origin in pdu.prev_pdus:
exists = yield self._get_persisted_pdu(pdu_id, origin)
for event_id, hashes in pdu.prev_events:
exists = yield self._get_persisted_pdu(origin, event_id)
if not exists:
logger.debug("Requesting pdu %s %s", pdu_id, origin)
logger.debug("Requesting pdu %s", event_id)
try:
yield self.get_pdu(
pdu.origin,
pdu_id=pdu_id,
pdu_origin=origin
event_id=event_id,
)
logger.debug("Processed pdu %s %s", pdu_id, origin)
logger.debug("Processed pdu %s", event_id)
except:
# TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU")
# Persist the Pdu, but don't mark it as processed yet.
yield self.store.persist_event(pdu=pdu)
else:
# We need to get the state at this event, since we have reached
# a backward extremity edge.
state = yield self.get_state_for_context(
origin, pdu.room_id, pdu.event_id,
)
if not backfilled:
ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled)
ret = yield self.handler.on_receive_pdu(
pdu,
backfilled=backfilled,
state=state,
)
else:
ret = None
yield self.pdu_actions.mark_as_processed(pdu)
# yield self.pdu_actions.mark_as_processed(pdu)
defer.returnValue(ret)
@ -476,14 +566,6 @@ class ReplicationLayer(object):
return "<ReplicationLayer(%s)>" % self.server_name
class ReplicationHandler(object):
"""This defines the methods that the :py:class:`.ReplicationLayer` will
use to communicate with the rest of the home server.
"""
def on_receive_pdu(self, pdu):
raise NotImplementedError("on_receive_pdu")
class _TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
a time for a given destination.
@ -509,6 +591,9 @@ class _TransactionQueue(object):
# destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = {}
# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}
# HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec())
@ -561,6 +646,18 @@ class _TransactionQueue(object):
return deferred
@defer.inlineCallbacks
def enqueue_failure(self, failure, destination):
deferred = defer.Deferred()
self.pending_failures_by_dest.setdefault(
destination, []
).append(
(failure, deferred)
)
yield deferred
@defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
@ -570,8 +667,9 @@ class _TransactionQueue(object):
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])
if not pending_pdus and not pending_edus:
if not pending_pdus and not pending_edus and not pending_failures:
return
logger.debug("TX [%s] Attempting new transaction", destination)
@ -581,7 +679,11 @@ class _TransactionQueue(object):
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
deferreds = [x[1] for x in pending_pdus + pending_edus]
failures = [x[0].get_dict() for x in pending_failures]
deferreds = [
x[1]
for x in pending_pdus + pending_edus + pending_failures
]
try:
self.pending_transactions[destination] = 1
@ -589,12 +691,13 @@ class _TransactionQueue(object):
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
origin_server_ts=self._clock.time_msec(),
origin_server_ts=int(self._clock.time_msec()),
transaction_id=str(self._next_txn_id),
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)
self._next_txn_id += 1
@ -614,7 +717,9 @@ class _TransactionQueue(object):
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
p["age"] = now - int(p["age_ts"])
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
code, response = yield self.transport_layer.send_transaction(

View File

@ -72,7 +72,7 @@ class TransportLayer(object):
self.received_handler = None
@log_function
def get_context_state(self, destination, context):
def get_context_state(self, destination, context, event_id=None):
""" Requests all state for a given context (i.e. room) from the
given server.
@ -89,54 +89,62 @@ class TransportLayer(object):
subpath = "/state/%s/" % context
return self._do_request_for_transaction(destination, subpath)
args = {}
if event_id:
args["event_id"] = event_id
return self._do_request_for_transaction(
destination, subpath, args=args
)
@log_function
def get_pdu(self, destination, pdu_origin, pdu_id):
def get_event(self, destination, event_id):
""" Requests the pdu with give id and origin from the given server.
Args:
destination (str): The host name of the remote home server we want
to get the state from.
pdu_origin (str): The home server which created the PDU.
pdu_id (str): The id of the PDU being requested.
event_id (str): The id of the event being requested.
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s",
destination, pdu_origin, pdu_id)
logger.debug("get_pdu dest=%s, event_id=%s",
destination, event_id)
subpath = "/pdu/%s/%s/" % (pdu_origin, pdu_id)
subpath = "/event/%s/" % (event_id, )
return self._do_request_for_transaction(destination, subpath)
@log_function
def backfill(self, dest, context, pdu_tuples, limit):
def backfill(self, dest, context, event_tuples, limit):
""" Requests `limit` previous PDUs in a given context before list of
PDUs.
Args:
dest (str)
context (str)
pdu_tuples (list)
event_tuples (list)
limt (int)
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug(
"backfill dest=%s, context=%s, pdu_tuples=%s, limit=%s",
dest, context, repr(pdu_tuples), str(limit)
"backfill dest=%s, context=%s, event_tuples=%s, limit=%s",
dest, context, repr(event_tuples), str(limit)
)
if not pdu_tuples:
if not event_tuples:
# TODO: raise?
return
subpath = "/backfill/%s/" % context
subpath = "/backfill/%s/" % (context,)
args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]}
args["limit"] = limit
args = {
"v": event_tuples,
"limit": limit,
}
return self._do_request_for_transaction(
dest,
@ -197,6 +205,72 @@ class TransportLayer(object):
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def make_join(self, destination, context, user_id, retry_on_dns_fail=True):
path = PREFIX + "/make_join/%s/%s" % (context, user_id,)
response = yield self.client.get_json(
destination=destination,
path=path,
retry_on_dns_fail=retry_on_dns_fail,
)
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def send_join(self, destination, context, event_id, content):
path = PREFIX + "/send_join/%s/%s" % (
context,
event_id,
)
code, content = yield self.client.put_json(
destination=destination,
path=path,
data=content,
)
if not 200 <= code < 300:
raise RuntimeError("Got %d from send_join", code)
defer.returnValue(json.loads(content))
@defer.inlineCallbacks
@log_function
def send_invite(self, destination, context, event_id, content):
path = PREFIX + "/invite/%s/%s" % (
context,
event_id,
)
code, content = yield self.client.put_json(
destination=destination,
path=path,
data=content,
)
if not 200 <= code < 300:
raise RuntimeError("Got %d from send_invite", code)
defer.returnValue(json.loads(content))
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, context, event_id):
path = PREFIX + "/event_auth/%s/%s" % (
context,
event_id,
)
response = yield self.client.get_json(
destination=destination,
path=path,
)
defer.returnValue(response)
@defer.inlineCallbacks
def _authenticate_request(self, request):
json_request = {
@ -210,7 +284,7 @@ class TransportLayer(object):
origin = None
if request.method == "PUT":
#TODO: Handle other method types? other content types?
# TODO: Handle other method types? other content types?
try:
content_bytes = request.content.read()
content = json.loads(content_bytes)
@ -222,11 +296,13 @@ class TransportLayer(object):
try:
params = auth.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value):
if value.startswith("\""):
return value[1:-1]
else:
return value
origin = strip_quotes(param_dict["origin"])
key = strip_quotes(param_dict["key"])
sig = strip_quotes(param_dict["sig"])
@ -247,7 +323,7 @@ class TransportLayer(object):
if auth.startswith("X-Matrix"):
(origin, key, sig) = parse_auth_header(auth)
json_request["origin"] = origin
json_request["signatures"].setdefault(origin,{})[key] = sig
json_request["signatures"].setdefault(origin, {})[key] = sig
if not json_request["signatures"]:
raise SynapseError(
@ -313,10 +389,10 @@ class TransportLayer(object):
# data_id pair.
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
re.compile("^" + PREFIX + "/event/([^/]*)/$"),
self._with_authentication(
lambda origin, content, query, pdu_origin, pdu_id:
handler.on_pdu_request(pdu_origin, pdu_id)
lambda origin, content, query, event_id:
handler.on_pdu_request(origin, event_id)
)
)
@ -326,7 +402,11 @@ class TransportLayer(object):
re.compile("^" + PREFIX + "/state/([^/]*)/$"),
self._with_authentication(
lambda origin, content, query, context:
handler.on_context_state_request(context)
handler.on_context_state_request(
origin,
context,
query.get("event_id", [None])[0],
)
)
)
@ -336,20 +416,11 @@ class TransportLayer(object):
self._with_authentication(
lambda origin, content, query, context:
self._on_backfill_request(
context, query["v"], query["limit"]
origin, context, query["v"], query["limit"]
)
)
)
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/context/([^/]*)/$"),
self._with_authentication(
lambda origin, content, query, context:
handler.on_context_pdus_request(context)
)
)
# This is when we receive a server-server Query
self.server.register_path(
"GET",
@ -362,6 +433,50 @@ class TransportLayer(object):
)
)
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, user_id:
self._on_make_join_request(
origin, content, query, context, user_id
)
)
)
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, event_id:
handler.on_event_auth(
origin, context, event_id,
)
)
)
self.server.register_path(
"PUT",
re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, event_id:
self._on_send_join_request(
origin, content, query,
)
)
)
self.server.register_path(
"PUT",
re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, event_id:
self._on_invite_request(
origin, content, query,
)
)
)
@defer.inlineCallbacks
@log_function
def _on_send_request(self, origin, content, query, transaction_id):
@ -402,7 +517,8 @@ class TransportLayer(object):
return
try:
code, response = yield self.received_handler.on_incoming_transaction(
handler = self.received_handler
code, response = yield handler.on_incoming_transaction(
transaction_data
)
except:
@ -440,7 +556,7 @@ class TransportLayer(object):
defer.returnValue(data)
@log_function
def _on_backfill_request(self, context, v_list, limits):
def _on_backfill_request(self, origin, context, v_list, limits):
if not limits:
return defer.succeed(
(400, {"error": "Did not include limit param"})
@ -448,124 +564,34 @@ class TransportLayer(object):
limit = int(limits[-1])
versions = [v.split(",", 1) for v in v_list]
versions = v_list
return self.request_handler.on_backfill_request(
context, versions, limit)
origin, context, versions, limit
)
@defer.inlineCallbacks
@log_function
def _on_make_join_request(self, origin, content, query, context, user_id):
content = yield self.request_handler.on_make_join_request(
context, user_id,
)
defer.returnValue((200, content))
class TransportReceivedHandler(object):
""" Callbacks used when we receive a transaction
"""
def on_incoming_transaction(self, transaction):
""" Called on PUT /send/<transaction_id>, or on response to a request
that we sent (e.g. a backfill request)
@defer.inlineCallbacks
@log_function
def _on_send_join_request(self, origin, content, query):
content = yield self.request_handler.on_send_join_request(
origin, content,
)
Args:
transaction (synapse.transaction.Transaction): The transaction that
was sent to us.
defer.returnValue((200, content))
Returns:
twisted.internet.defer.Deferred: A deferred that gets fired when
the transaction has finished being processed.
@defer.inlineCallbacks
@log_function
def _on_invite_request(self, origin, content, query):
content = yield self.request_handler.on_invite_request(
origin, content,
)
The result should be a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.
On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass
class TransportRequestHandler(object):
""" Handlers used when someone want's data from us
"""
def on_pull_request(self, versions):
""" Called on GET /pull/?v=...
This is hit when a remote home server wants to get all data
after a given transaction. Mainly used when a home server comes back
online and wants to get everything it has missed.
Args:
versions (list): A list of transaction_ids that should be used to
determine what PDUs the remote side have not yet seen.
Returns:
Deferred: Resultsin a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.
On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass
def on_pdu_request(self, pdu_origin, pdu_id):
""" Called on GET /pdu/<pdu_origin>/<pdu_id>/
Someone wants a particular PDU. This PDU may or may not have originated
from us.
Args:
pdu_origin (str)
pdu_id (str)
Returns:
Deferred: Resultsin a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.
On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass
def on_context_state_request(self, context):
""" Called on GET /state/<context>/
Gets hit when someone wants all the *current* state for a given
contexts.
Args:
context (str): The name of the context that we're interested in.
Returns:
twisted.internet.defer.Deferred: A deferred that gets fired when
the transaction has finished being processed.
The result should be a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.
On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass
def on_backfill_request(self, context, versions, limit):
""" Called on GET /backfill/<context>/?v=...&limit=...
Gets hit when we want to backfill backwards on a given context from
the given point.
Args:
context (str): The context to backfill
versions (list): A list of 2-tuples representing where to backfill
from, in the form `(pdu_id, origin)`
limit (int): How many pdus to return.
Returns:
Deferred: Results in a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.
On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass
def on_query_request(self):
""" Called on a GET /query/<query_type> request. """
defer.returnValue((200, content))

View File

@ -20,8 +20,6 @@ server protocol.
from synapse.util.jsonobject import JsonEncodedObject
import logging
import json
import copy
logger = logging.getLogger(__name__)
@ -33,13 +31,13 @@ class Pdu(JsonEncodedObject):
A Pdu can be classified as "state". For a given context, we can efficiently
retrieve all state pdu's that haven't been clobbered. Clobbering is done
via a unique constraint on the tuple (context, pdu_type, state_key). A pdu
via a unique constraint on the tuple (context, type, state_key). A pdu
is a state pdu if `is_state` is True.
Example pdu::
{
"pdu_id": "78c",
"event_id": "$78c:example.com",
"origin_server_ts": 1404835423000,
"origin": "bar",
"prev_ids": [
@ -52,24 +50,21 @@ class Pdu(JsonEncodedObject):
"""
valid_keys = [
"pdu_id",
"context",
"event_id",
"room_id",
"origin",
"origin_server_ts",
"pdu_type",
"type",
"destinations",
"transaction_id",
"prev_pdus",
"prev_events",
"depth",
"content",
"outlier",
"is_state", # Below this are keys valid only for State Pdus.
"state_key",
"power_level",
"prev_state_id",
"prev_state_origin",
"required_power_level",
"hashes",
"user_id",
"auth_events",
"signatures", # Below this are keys valid only for State Pdus.
"state_key",
"prev_state",
]
internal_keys = [
@ -79,61 +74,28 @@ class Pdu(JsonEncodedObject):
]
required_keys = [
"pdu_id",
"context",
"event_id",
"room_id",
"origin",
"origin_server_ts",
"pdu_type",
"type",
"content",
]
# TODO: We need to make this properly load content rather than
# just leaving it as a dict. (OR DO WE?!)
def __init__(self, destinations=[], is_state=False, prev_pdus=[],
outlier=False, **kwargs):
if is_state:
for required_key in ["state_key"]:
if required_key not in kwargs:
raise RuntimeError("Key %s is required" % required_key)
def __init__(self, destinations=[], prev_events=[],
outlier=False, hashes={}, signatures={}, **kwargs):
super(Pdu, self).__init__(
destinations=destinations,
is_state=is_state,
prev_pdus=prev_pdus,
prev_events=prev_events,
outlier=outlier,
hashes=hashes,
signatures=signatures,
**kwargs
)
@classmethod
def from_pdu_tuple(cls, pdu_tuple):
""" Converts a PduTuple to a Pdu
Args:
pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to
convert
Returns:
Pdu
"""
if pdu_tuple:
d = copy.copy(pdu_tuple.pdu_entry._asdict())
d["origin_server_ts"] = d.pop("ts")
d["content"] = json.loads(d["content_json"])
del d["content_json"]
args = {f: d[f] for f in cls.valid_keys if f in d}
if "unrecognized_keys" in d and d["unrecognized_keys"]:
args.update(json.loads(d["unrecognized_keys"]))
return Pdu(
prev_pdus=pdu_tuple.prev_pdu_list,
**args
)
else:
return None
def __str__(self):
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))
@ -193,6 +155,7 @@ class Transaction(JsonEncodedObject):
"edus",
"transaction_id",
"destination",
"pdu_failures",
]
internal_keys = [
@ -229,7 +192,9 @@ class Transaction(JsonEncodedObject):
transaction_id and origin_server_ts keys.
"""
if "origin_server_ts" not in kwargs:
raise KeyError("Require 'origin_server_ts' to construct a Transaction")
raise KeyError(
"Require 'origin_server_ts' to construct a Transaction"
)
if "transaction_id" not in kwargs:
raise KeyError(
"Require 'transaction_id' to construct a Transaction"
@ -241,6 +206,3 @@ class Transaction(JsonEncodedObject):
kwargs["pdus"] = [p.get_dict() for p in pdus]
return Transaction(**kwargs)

View File

@ -14,7 +14,18 @@
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.events.room import RoomMemberEvent
from synapse.api.constants import Membership
import logging
logger = logging.getLogger(__name__)
class BaseHandler(object):
@ -30,6 +41,9 @@ class BaseHandler(object):
self.clock = hs.get_clock()
self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
def ratelimit(self, user_id):
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message(
@ -44,16 +58,58 @@ class BaseHandler(object):
@defer.inlineCallbacks
def _on_new_room_event(self, event, snapshot, extra_destinations=[],
extra_users=[]):
extra_users=[], suppress_auth=False,
do_invite_host=None):
yield run_on_reactor()
snapshot.fill_out_prev_events(event)
yield self.state_handler.annotate_state_groups(event)
yield self.auth.add_auth_events(event)
logger.debug("Signing event...")
add_hashes_and_signatures(
event, self.server_name, self.signing_key
)
logger.debug("Signed event.")
if not suppress_auth:
logger.debug("Authing...")
self.auth.check(event, raises=True)
logger.debug("Authed")
else:
logger.debug("Suppressed auth.")
if do_invite_host:
federation_handler = self.hs.get_handlers().federation_handler
invite_event = yield federation_handler.send_invite(
do_invite_host,
event
)
# FIXME: We need to check if the remote changed anything else
event.signatures = invite_event.signatures
yield self.store.persist_event(event)
destinations = set(extra_destinations)
# Send a PDU to all hosts who have joined the room.
destinations.update((yield self.store.get_joined_hosts_for_room(
event.room_id
)))
for k, s in event.state_events.items():
try:
if k[0] == RoomMemberEvent.TYPE:
if s.content["membership"] == Membership.JOIN:
destinations.add(
self.hs.parse_userid(s.state_key).domain
)
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
event.destinations = list(destinations)
self.notifier.on_new_room_event(event, extra_users=extra_users)

View File

@ -147,10 +147,8 @@ class DirectoryHandler(BaseHandler):
content={"aliases": aliases},
)
snapshot = yield self.store.snapshot_room(
room_id=room_id,
user_id=user_id,
)
snapshot = yield self.store.snapshot_room(event)
yield self.state_handler.handle_new_event(event, snapshot)
yield self._on_new_room_event(event, snapshot, extra_users=[user_id])
yield self._on_new_room_event(
event, snapshot, extra_users=[user_id], suppress_auth=True
)

View File

@ -17,13 +17,15 @@
from ._base import BaseHandler
from synapse.api.events.room import InviteJoinEvent, RoomMemberEvent
from synapse.api.errors import AuthError, FederationError
from synapse.api.events.room import RoomMemberEvent
from synapse.api.constants import Membership
from synapse.util.logutils import log_function
from synapse.federation.pdu_codec import PduCodec
from synapse.api.errors import SynapseError
from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import compute_event_signature
from twisted.internet import defer, reactor
from twisted.internet import defer
import logging
@ -62,6 +64,9 @@ class FederationHandler(BaseHandler):
self.pdu_codec = PduCodec(hs)
# When joining a room we need to queue any events for that room up
self.room_queues = {}
@log_function
@defer.inlineCallbacks
def handle_new_event(self, event, snapshot):
@ -78,6 +83,8 @@ class FederationHandler(BaseHandler):
processing.
"""
yield run_on_reactor()
pdu = self.pdu_codec.pdu_from_event(event)
if not hasattr(pdu, "destinations") or not pdu.destinations:
@ -87,97 +94,88 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def on_receive_pdu(self, pdu, backfilled):
def on_receive_pdu(self, pdu, backfilled, state=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it throught the StateHandler.
do auth checks and put it through the StateHandler.
"""
event = self.pdu_codec.event_from_pdu(pdu)
logger.debug("Got event: %s", event.event_id)
with (yield self.lock_manager.lock(pdu.context)):
if event.is_state and not backfilled:
is_new_state = yield self.state_handler.handle_new_state(
pdu
)
else:
is_new_state = False
if event.room_id in self.room_queues:
self.room_queues[event.room_id].append(pdu)
return
logger.debug("Processing event: %s", event.event_id)
if state:
state = [self.pdu_codec.event_from_pdu(p) for p in state]
is_new_state = yield self.state_handler.annotate_state_groups(
event,
old_state=state
)
logger.debug("Event: %s", event)
try:
self.auth.check(event, raises=True)
except AuthError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event.event_id,
)
is_new_state = is_new_state and not backfilled
# TODO: Implement something in federation that allows us to
# respond to PDU.
target_is_mine = False
if hasattr(event, "target_host"):
target_is_mine = event.target_host == self.hs.hostname
yield self.store.persist_event(
event,
backfilled,
is_new_state=is_new_state
)
if event.type == InviteJoinEvent.TYPE:
if not target_is_mine:
logger.debug("Ignoring invite/join event %s", event)
return
room = yield self.store.get_room(event.room_id)
# If we receive an invite/join event then we need to join the
# sender to the given room.
# TODO: We should probably auth this or some such
content = event.content
content.update({"membership": Membership.JOIN})
new_event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
state_key=event.user_id,
room_id=event.room_id,
user_id=event.user_id,
membership=Membership.JOIN,
content=content
)
yield self.hs.get_handlers().room_member_handler.change_membership(
new_event,
do_auth=False,
)
else:
with (yield self.room_lock.lock(event.room_id)):
yield self.store.persist_event(
event,
backfilled,
is_new_state=is_new_state
if not room:
# Huh, let's try and get the current state
try:
yield self.replication_layer.get_state_for_context(
event.origin, event.room_id, event.event_id,
)
room = yield self.store.get_room(event.room_id)
if not room:
# Huh, let's try and get the current state
try:
yield self.replication_layer.get_state_for_context(
event.origin, event.room_id
)
hosts = yield self.store.get_joined_hosts_for_room(
event.room_id
)
if self.hs.hostname in hosts:
try:
yield self.store.store_room(
room_id=event.room_id,
room_creator_user_id="",
is_public=False,
)
except:
pass
except:
logger.exception(
"Failed to get current state for room %s",
event.room_id
)
if not backfilled:
extra_users = []
if event.type == RoomMemberEvent.TYPE:
target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id)
extra_users.append(target_user)
yield self.notifier.on_new_room_event(
event, extra_users=extra_users
hosts = yield self.store.get_joined_hosts_for_room(
event.room_id
)
if self.hs.hostname in hosts:
try:
yield self.store.store_room(
room_id=event.room_id,
room_creator_user_id="",
is_public=False,
)
except:
pass
except:
logger.exception(
"Failed to get current state for room %s",
event.room_id
)
if not backfilled:
extra_users = []
if event.type == RoomMemberEvent.TYPE:
target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id)
extra_users.append(target_user)
yield self.notifier.on_new_room_event(
event, extra_users=extra_users
)
if event.type == RoomMemberEvent.TYPE:
if event.membership == Membership.JOIN:
@ -189,79 +187,344 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit):
pdus = yield self.replication_layer.backfill(dest, room_id, limit)
extremities = yield self.store.get_oldest_events_in_room(room_id)
pdus = yield self.replication_layer.backfill(
dest,
room_id,
limit,
extremities=extremities,
)
events = []
for pdu in pdus:
event = self.pdu_codec.event_from_pdu(pdu)
# FIXME (erikj): Not sure this actually works :/
yield self.state_handler.annotate_state_groups(event)
events.append(event)
yield self.store.persist_event(event, backfilled=True)
defer.returnValue(events)
@defer.inlineCallbacks
def send_invite(self, target_host, event):
pdu = yield self.replication_layer.send_invite(
destination=target_host,
context=event.room_id,
event_id=event.event_id,
pdu=self.pdu_codec.pdu_from_event(event)
)
defer.returnValue(self.pdu_codec.event_from_pdu(pdu))
@defer.inlineCallbacks
def on_event_auth(self, event_id):
auth = yield self.store.get_auth_chain(event_id)
defer.returnValue([self.pdu_codec.pdu_from_event(e) for e in auth])
@log_function
@defer.inlineCallbacks
def do_invite_join(self, target_host, room_id, joinee, content, snapshot):
hosts = yield self.store.get_joined_hosts_for_room(room_id)
if self.hs.hostname in hosts:
# We are already in the room.
logger.debug("We're already in the room apparently")
defer.returnValue(False)
# First get current state to see if we are already joined.
try:
yield self.replication_layer.get_state_for_context(
target_host, room_id
)
hosts = yield self.store.get_joined_hosts_for_room(room_id)
if self.hs.hostname in hosts:
# Oh, we were actually in the room already.
logger.debug("We're already in the room apparently")
defer.returnValue(False)
except Exception:
logger.exception("Failed to get current state")
new_event = self.event_factory.create_event(
etype=InviteJoinEvent.TYPE,
target_host=target_host,
room_id=room_id,
user_id=joinee,
content=content
pdu = yield self.replication_layer.make_join(
target_host,
room_id,
joinee
)
new_event.destinations = [target_host]
logger.debug("Got response to make_join: %s", pdu)
snapshot.fill_out_prev_events(new_event)
yield self.handle_new_event(new_event, snapshot)
event = self.pdu_codec.event_from_pdu(pdu)
# TODO (erikj): Time out here.
d = defer.Deferred()
self.waiting_for_join_list.setdefault((joinee, room_id), []).append(d)
reactor.callLater(10, d.cancel)
# We should assert some things.
assert(event.type == RoomMemberEvent.TYPE)
assert(event.user_id == joinee)
assert(event.state_key == joinee)
assert(event.room_id == room_id)
event.outlier = False
self.room_queues[room_id] = []
try:
yield d
except defer.CancelledError:
raise SynapseError(500, "Unable to join remote room")
event.event_id = self.event_factory.create_event_id()
event.content = content
try:
yield self.store.store_room(
room_id=room_id,
room_creator_user_id="",
is_public=False
state = yield self.replication_layer.send_join(
target_host,
self.pdu_codec.pdu_from_event(event)
)
except:
pass
state = [self.pdu_codec.event_from_pdu(p) for p in state]
logger.debug("do_invite_join state: %s", state)
is_new_state = yield self.state_handler.annotate_state_groups(
event,
old_state=state
)
logger.debug("do_invite_join event: %s", event)
try:
yield self.store.store_room(
room_id=room_id,
room_creator_user_id="",
is_public=False
)
except:
# FIXME
pass
for e in state:
# FIXME: Auth these.
e.outlier = True
yield self.state_handler.annotate_state_groups(
e,
)
yield self.store.persist_event(
e,
backfilled=False,
is_new_state=False
)
yield self.store.persist_event(
event,
backfilled=False,
is_new_state=is_new_state
)
finally:
room_queue = self.room_queues[room_id]
del self.room_queues[room_id]
for p in room_queue:
try:
yield self.on_receive_pdu(p, backfilled=False)
except:
pass
defer.returnValue(True)
@defer.inlineCallbacks
@log_function
def on_make_join_request(self, context, user_id):
event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
content={"membership": Membership.JOIN},
room_id=context,
user_id=user_id,
state_key=user_id,
)
snapshot = yield self.store.snapshot_room(event)
snapshot.fill_out_prev_events(event)
yield self.state_handler.annotate_state_groups(event)
yield self.auth.add_auth_events(event)
self.auth.check(event, raises=True)
pdu = self.pdu_codec.pdu_from_event(event)
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
def on_send_join_request(self, origin, pdu):
event = self.pdu_codec.event_from_pdu(pdu)
event.outlier = False
is_new_state = yield self.state_handler.annotate_state_groups(event)
self.auth.check(event, raises=True)
# FIXME (erikj): All this is duplicated above :(
yield self.store.persist_event(
event,
backfilled=False,
is_new_state=is_new_state
)
extra_users = []
if event.type == RoomMemberEvent.TYPE:
target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id)
extra_users.append(target_user)
yield self.notifier.on_new_room_event(
event, extra_users=extra_users
)
if event.type == RoomMemberEvent.TYPE:
if event.membership == Membership.JOIN:
user = self.hs.parse_userid(event.state_key)
self.distributor.fire(
"user_joined_room", user=user, room_id=event.room_id
)
new_pdu = self.pdu_codec.pdu_from_event(event)
destinations = set()
for k, s in event.state_events.items():
try:
if k[0] == RoomMemberEvent.TYPE:
if s.content["membership"] == Membership.JOIN:
destinations.add(
self.hs.parse_userid(s.state_key).domain
)
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
new_pdu.destinations = list(destinations)
yield self.replication_layer.send_pdu(new_pdu)
auth_chain = yield self.store.get_auth_chain(event.event_id)
pdu_auth_chain = [
self.pdu_codec.pdu_from_event(e)
for e in auth_chain
]
defer.returnValue({
"state": [
self.pdu_codec.pdu_from_event(e)
for e in event.state_events.values()
],
"auth_chain": pdu_auth_chain,
})
@defer.inlineCallbacks
def on_invite_request(self, origin, pdu):
event = self.pdu_codec.event_from_pdu(pdu)
event.outlier = True
event.signatures.update(
compute_event_signature(
event,
self.hs.hostname,
self.hs.config.signing_key[0]
)
)
yield self.state_handler.annotate_state_groups(event)
yield self.store.persist_event(
event,
backfilled=False,
)
target_user = self.hs.parse_userid(event.state_key)
yield self.notifier.on_new_room_event(
event, extra_users=[target_user],
)
defer.returnValue(self.pdu_codec.pdu_from_event(event))
@defer.inlineCallbacks
def get_state_for_pdu(self, origin, room_id, event_id):
yield run_on_reactor()
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
state_groups = yield self.store.get_state_groups(
[event_id]
)
if state_groups:
_, state = state_groups.items().pop()
results = {
(e.type, e.state_key): e for e in state
}
event = yield self.store.get_event(event_id)
if hasattr(event, "state_key"):
# Get previous state
if hasattr(event, "replaces_state") and event.replaces_state:
prev_event = yield self.store.get_event(
event.replaces_state
)
results[(event.type, event.state_key)] = prev_event
else:
del results[(event.type, event.state_key)]
defer.returnValue(
[
self.pdu_codec.pdu_from_event(s)
for s in results.values()
]
)
else:
defer.returnValue([])
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, context, pdu_list, limit):
in_room = yield self.auth.check_host_in_room(context, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
events = yield self.store.get_backfill_events(
context,
pdu_list,
limit
)
defer.returnValue([
self.pdu_codec.pdu_from_event(e)
for e in events
])
@defer.inlineCallbacks
@log_function
def get_persisted_pdu(self, origin, event_id):
""" Get a PDU from the database with given origin and id.
Returns:
Deferred: Results in a `Pdu`.
"""
event = yield self.store.get_event(
event_id,
allow_none=True,
)
if event:
in_room = yield self.auth.check_host_in_room(
event.room_id,
origin
)
if not in_room:
raise AuthError(403, "Host not in room.")
defer.returnValue(self.pdu_codec.pdu_from_event(event))
else:
defer.returnValue(None)
@log_function
def get_min_depth_for_context(self, context):
return self.store.get_min_depth(context)
@log_function
def _on_user_joined(self, user, room_id):
waiters = self.waiting_for_join_list.get((user.to_string(), room_id), [])
waiters = self.waiting_for_join_list.get(
(user.to_string(), room_id),
[]
)
while waiters:
waiters.pop().callback(None)

View File

@ -81,12 +81,11 @@ class MessageHandler(BaseHandler):
user = self.hs.parse_userid(event.user_id)
assert user.is_mine, "User must be our own: %s" % (user,)
snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
snapshot = yield self.store.snapshot_room(event)
if not suppress_auth:
yield self.auth.check(event, snapshot, raises=True)
yield self._on_new_room_event(event, snapshot)
yield self._on_new_room_event(
event, snapshot, suppress_auth=suppress_auth
)
self.hs.get_handlers().presence_handler.bump_presence_active_time(
user
@ -142,16 +141,7 @@ class MessageHandler(BaseHandler):
SynapseError if something went wrong.
"""
snapshot = yield self.store.snapshot_room(
event.room_id,
event.user_id,
state_type=event.type,
state_key=event.state_key,
)
yield self.auth.check(event, snapshot, raises=True)
yield self.state_handler.handle_new_event(event, snapshot)
snapshot = yield self.store.snapshot_room(event)
yield self._on_new_room_event(event, snapshot)
@ -201,7 +191,7 @@ class MessageHandler(BaseHandler):
raise RoomError(
403, "Member does not meet private room rules.")
data = yield self.store.get_current_state(
data = yield self.state_handler.get_current_state(
room_id, event_type, state_key
)
defer.returnValue(data)
@ -219,9 +209,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def send_feedback(self, event):
snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
yield self.auth.check(event, snapshot, raises=True)
snapshot = yield self.store.snapshot_room(event)
# store message in db
yield self._on_new_room_event(event, snapshot)
@ -239,7 +227,7 @@ class MessageHandler(BaseHandler):
yield self.auth.check_joined_room(room_id, user_id)
# TODO: This is duplicating logic from snapshot_all_rooms
current_state = yield self.store.get_current_state(room_id)
current_state = yield self.state_handler.get_current_state(room_id)
defer.returnValue([self.hs.serialize_event(c) for c in current_state])
@defer.inlineCallbacks
@ -316,7 +304,7 @@ class MessageHandler(BaseHandler):
"end": end_token.to_string(),
}
current_state = yield self.store.get_current_state(
current_state = yield self.state_handler.get_current_state(
event.room_id
)
d["state"] = [self.hs.serialize_event(c) for c in current_state]

View File

@ -17,7 +17,6 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import Membership
from synapse.api.events.room import RoomMemberEvent
from ._base import BaseHandler
@ -153,10 +152,13 @@ class ProfileHandler(BaseHandler):
if not user.is_mine:
defer.returnValue(None)
(displayname, avatar_url) = yield defer.gatherResults([
self.store.get_profile_displayname(user.localpart),
self.store.get_profile_avatar_url(user.localpart),
])
(displayname, avatar_url) = yield defer.gatherResults(
[
self.store.get_profile_displayname(user.localpart),
self.store.get_profile_avatar_url(user.localpart),
],
consumeErrors=True
)
state["displayname"] = displayname
state["avatar_url"] = avatar_url
@ -196,10 +198,7 @@ class ProfileHandler(BaseHandler):
)
for j in joins:
snapshot = yield self.store.snapshot_room(
j.room_id, j.state_key, RoomMemberEvent.TYPE,
j.state_key
)
snapshot = yield self.store.snapshot_room(j)
content = {
"membership": j.content["membership"],
@ -218,5 +217,6 @@ class ProfileHandler(BaseHandler):
user_id=j.state_key,
)
yield self.state_handler.handle_new_event(new_event, snapshot)
yield self._on_new_room_event(new_event, snapshot)
yield self._on_new_room_event(
new_event, snapshot, suppress_auth=True
)

View File

@ -21,8 +21,7 @@ from synapse.api.constants import Membership, JoinRules
from synapse.api.errors import StoreError, SynapseError
from synapse.api.events.room import (
RoomMemberEvent, RoomCreateEvent, RoomPowerLevelsEvent,
RoomJoinRulesEvent, RoomAddStateLevelEvent, RoomTopicEvent,
RoomSendEventLevelEvent, RoomOpsPowerLevelsEvent, RoomNameEvent,
RoomTopicEvent, RoomNameEvent, RoomJoinRulesEvent,
)
from synapse.util import stringutils
from ._base import BaseHandler
@ -122,15 +121,13 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def handle_event(event):
snapshot = yield self.store.snapshot_room(
room_id=room_id,
user_id=user_id,
)
snapshot = yield self.store.snapshot_room(event)
logger.debug("Event: %s", event)
yield self.state_handler.handle_new_event(event, snapshot)
yield self._on_new_room_event(event, snapshot, extra_users=[user])
yield self._on_new_room_event(
event, snapshot, extra_users=[user], suppress_auth=True
)
for event in creation_events:
yield handle_event(event)
@ -141,7 +138,6 @@ class RoomCreationHandler(BaseHandler):
etype=RoomNameEvent.TYPE,
room_id=room_id,
user_id=user_id,
required_power_level=50,
content={"name": name},
)
@ -153,7 +149,6 @@ class RoomCreationHandler(BaseHandler):
etype=RoomTopicEvent.TYPE,
room_id=room_id,
user_id=user_id,
required_power_level=50,
content={"topic": topic},
)
@ -198,7 +193,6 @@ class RoomCreationHandler(BaseHandler):
event_keys = {
"room_id": room_id,
"user_id": creator.to_string(),
"required_power_level": 100,
}
def create(etype, **content):
@ -215,7 +209,21 @@ class RoomCreationHandler(BaseHandler):
power_levels_event = self.event_factory.create_event(
etype=RoomPowerLevelsEvent.TYPE,
content={creator.to_string(): 100, "default": 0},
content={
"users": {
creator.to_string(): 100,
},
"users_default": 0,
"events": {
RoomNameEvent.TYPE: 100,
RoomPowerLevelsEvent.TYPE: 100,
},
"events_default": 0,
"state_default": 50,
"ban": 50,
"kick": 50,
"redact": 50
},
**event_keys
)
@ -225,30 +233,10 @@ class RoomCreationHandler(BaseHandler):
join_rule=join_rule,
)
add_state_event = create(
etype=RoomAddStateLevelEvent.TYPE,
level=100,
)
send_event = create(
etype=RoomSendEventLevelEvent.TYPE,
level=0,
)
ops = create(
etype=RoomOpsPowerLevelsEvent.TYPE,
ban_level=50,
kick_level=50,
redact_level=50,
)
return [
creation_event,
power_levels_event,
join_rules_event,
add_state_event,
send_event,
ops,
]
@ -363,10 +351,8 @@ class RoomMemberHandler(BaseHandler):
"""
target_user_id = event.state_key
snapshot = yield self.store.snapshot_room(
event.room_id, event.user_id,
RoomMemberEvent.TYPE, target_user_id
)
snapshot = yield self.store.snapshot_room(event)
## TODO(markjh): get prev state from snapshot.
prev_state = yield self.store.get_room_member(
target_user_id, event.room_id
@ -375,13 +361,6 @@ class RoomMemberHandler(BaseHandler):
if prev_state:
event.content["prev"] = prev_state.membership
# if prev_state and prev_state.membership == event.membership:
# # treat this event as a NOOP.
# if do_auth: # This is mainly to fix a unit test.
# yield self.auth.check(event, raises=True)
# defer.returnValue({})
# return
room_id = event.room_id
# If we're trying to join a room then we have to do this differently
@ -391,29 +370,17 @@ class RoomMemberHandler(BaseHandler):
yield self._do_join(event, snapshot, do_auth=do_auth)
else:
# This is not a JOIN, so we can handle it normally.
if do_auth:
yield self.auth.check(event, snapshot, raises=True)
# If we're banning someone, set a req power level
if event.membership == Membership.BAN:
if not hasattr(event, "required_power_level") or event.required_power_level is None:
# Add some default required_power_level
user_level = yield self.store.get_power_level(
event.room_id,
event.user_id,
)
event.required_power_level = user_level
if prev_state and prev_state.membership == event.membership:
# double same action, treat this event as a NOOP.
defer.returnValue({})
return
yield self.state_handler.handle_new_event(event, snapshot)
yield self._do_local_membership_update(
event,
membership=event.content["membership"],
snapshot=snapshot,
do_auth=do_auth,
)
defer.returnValue({"room_id": room_id})
@ -443,10 +410,7 @@ class RoomMemberHandler(BaseHandler):
content=content,
)
snapshot = yield self.store.snapshot_room(
room_id, joinee.to_string(), RoomMemberEvent.TYPE,
joinee.to_string()
)
snapshot = yield self.store.snapshot_room(new_event)
yield self._do_join(new_event, snapshot, room_host=host, do_auth=True)
@ -502,14 +466,11 @@ class RoomMemberHandler(BaseHandler):
if not have_joined:
logger.debug("Doing normal join")
if do_auth:
yield self.auth.check(event, snapshot, raises=True)
yield self.state_handler.handle_new_event(event, snapshot)
yield self._do_local_membership_update(
event,
membership=event.content["membership"],
snapshot=snapshot,
do_auth=do_auth,
)
user = self.hs.parse_userid(event.user_id)
@ -553,26 +514,27 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue([r.room_id for r in rooms])
def _do_local_membership_update(self, event, membership, snapshot):
destinations = []
@defer.inlineCallbacks
def _do_local_membership_update(self, event, membership, snapshot,
do_auth):
# If we're inviting someone, then we should also send it to that
# HS.
target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id)
if membership == Membership.INVITE:
host = target_user.domain
destinations.append(host)
if membership == Membership.INVITE and not target_user.is_mine:
do_invite_host = target_user.domain
else:
do_invite_host = None
# Always include target domain
host = target_user.domain
destinations.append(host)
return self._on_new_room_event(
event, snapshot, extra_destinations=destinations,
extra_users=[target_user]
yield self._on_new_room_event(
event,
snapshot,
extra_users=[target_user],
suppress_auth=(not do_auth),
do_invite_host=do_invite_host,
)
class RoomListHandler(BaseHandler):
@defer.inlineCallbacks

View File

@ -18,6 +18,11 @@ from synapse.api.urls import CLIENT_PREFIX
from synapse.rest.transactions import HttpTransactionStore
import re
import logging
logger = logging.getLogger(__name__)
def client_path_pattern(path_regex):
"""Creates a regex compiled client path with the correct client path
@ -62,6 +67,8 @@ class RestServlet(object):
self.auth = hs.get_auth()
self.txns = HttpTransactionStore()
self.validator = hs.get_event_validator()
def register(self, http_server):
""" Register this servlet with the given HTTP server. """
if hasattr(self, "PATTERN"):

View File

@ -20,6 +20,12 @@ from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig
from synapse.rest.base import RestServlet, client_path_pattern
import logging
logger = logging.getLogger(__name__)
class EventStreamRestServlet(RestServlet):
PATTERN = client_path_pattern("/events$")
@ -29,18 +35,22 @@ class EventStreamRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
auth_user = yield self.auth.get_user_by_req(request)
try:
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args:
try:
timeout = int(request.args["timeout"][0])
except ValueError:
raise SynapseError(400, "timeout must be in milliseconds.")
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args:
try:
timeout = int(request.args["timeout"][0])
except ValueError:
raise SynapseError(400, "timeout must be in milliseconds.")
chunk = yield handler.get_stream(auth_user.to_string(), pagin_config,
timeout=timeout)
chunk = yield handler.get_stream(
auth_user.to_string(), pagin_config, timeout=timeout
)
except:
logger.exception("Event stream failed")
raise
defer.returnValue((200, chunk))

View File

@ -138,7 +138,7 @@ class RoomStateEventRestServlet(RestServlet):
raise SynapseError(
404, "Event not found.", errcode=Codes.NOT_FOUND
)
defer.returnValue((200, data[0].get_dict()["content"]))
defer.returnValue((200, data.get_dict()["content"]))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key):
@ -154,6 +154,9 @@ class RoomStateEventRestServlet(RestServlet):
user_id=user.to_string(),
state_key=urllib.unquote(state_key)
)
self.validator.validate(event)
if event_type == RoomMemberEvent.TYPE:
# membership events are special
handler = self.handlers.room_member_handler
@ -188,6 +191,8 @@ class RoomSendEventRestServlet(RestServlet):
content=content
)
self.validator.validate(event)
msg_handler = self.handlers.message_handler
yield msg_handler.send_message(event)
@ -253,6 +258,9 @@ class JoinRoomAliasServlet(RestServlet):
user_id=user.to_string(),
state_key=user.to_string()
)
self.validator.validate(event)
handler = self.handlers.room_member_handler
yield handler.change_membership(event)
defer.returnValue((200, {}))
@ -424,6 +432,9 @@ class RoomMembershipRestServlet(RestServlet):
user_id=user.to_string(),
state_key=state_key
)
self.validator.validate(event)
handler = self.handlers.room_member_handler
yield handler.change_membership(event)
defer.returnValue((200, {}))
@ -461,6 +472,8 @@ class RoomRedactEventRestServlet(RestServlet):
redacts=urllib.unquote(event_id),
)
self.validator.validate(event)
msg_handler = self.handlers.message_handler
yield msg_handler.send_message(event)

View File

@ -22,13 +22,14 @@
from synapse.federation import initialize_http_replication
from synapse.api.events import serialize_event
from synapse.api.events.factory import EventFactory
from synapse.api.events.validator import EventValidator
from synapse.notifier import Notifier
from synapse.api.auth import Auth
from synapse.handlers import Handlers
from synapse.rest import RestServletFactory
from synapse.state import StateHandler
from synapse.storage import DataStore
from synapse.types import UserID, RoomAlias, RoomID
from synapse.types import UserID, RoomAlias, RoomID, EventID
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.lockutils import LockManager
@ -80,6 +81,7 @@ class BaseHomeServer(object):
'event_sources',
'ratelimiter',
'keyring',
'event_validator',
]
def __init__(self, hostname, **kwargs):
@ -143,6 +145,11 @@ class BaseHomeServer(object):
object."""
return RoomID.from_string(s, hs=self)
def parse_eventid(self, s):
"""Parse the string given by 's' as a Event ID and return a EventID
object."""
return EventID.from_string(s, hs=self)
def serialize_event(self, e):
return serialize_event(self, e)
@ -218,6 +225,9 @@ class HomeServer(BaseHomeServer):
def build_keyring(self):
return Keyring(self)
def build_event_validator(self):
return EventValidator(self)
def register_servlets(self):
""" Register all servlets associated with this HomeServer.
"""

View File

@ -16,11 +16,13 @@
from twisted.internet import defer
from synapse.federation.pdu_codec import encode_event_id, decode_event_id
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.api.events.room import RoomPowerLevelsEvent
from collections import namedtuple
import copy
import logging
import hashlib
@ -35,230 +37,169 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
class StateHandler(object):
""" Repsonsible for doing state conflict resolution.
""" Responsible for doing state conflict resolution.
"""
def __init__(self, hs):
self.store = hs.get_datastore()
self._replication = hs.get_replication_layer()
self.server_name = hs.hostname
@defer.inlineCallbacks
@log_function
def handle_new_event(self, event, snapshot):
""" Given an event this works out if a) we have sufficient power level
to update the state and b) works out what the prev_state should be.
def annotate_state_groups(self, event, old_state=None):
yield run_on_reactor()
Returns:
Deferred: Resolved with a boolean indicating if we succesfully
updated the state.
if old_state:
event.state_group = None
event.old_state_events = {
(s.type, s.state_key): s for s in old_state
}
event.state_events = event.old_state_events
Raised:
AuthError
"""
# This needs to be done in a transaction.
if hasattr(event, "state_key"):
event.state_events[(event.type, event.state_key)] = event
if not hasattr(event, "state_key"):
defer.returnValue(False)
return
key = KeyStateTuple(
event.room_id,
event.type,
_get_state_key_from_event(event)
)
if hasattr(event, "outlier") and event.outlier:
event.state_group = None
event.old_state_events = None
event.state_events = {}
defer.returnValue(False)
return
# Now I need to fill out the prev state and work out if it has auth
# (w.r.t. to power levels)
ids = [e for e, _ in event.prev_events]
snapshot.fill_out_prev_events(event)
ret = yield self.resolve_state_groups(ids)
state_group, new_state = ret
event.prev_events = [
e for e in event.prev_events if e != event.event_id
event.old_state_events = copy.deepcopy(new_state)
if hasattr(event, "state_key"):
key = (event.type, event.state_key)
if key in new_state:
event.replaces_state = new_state[key].event_id
new_state[key] = event
elif state_group:
event.state_group = state_group
event.state_events = new_state
defer.returnValue(False)
event.state_group = None
event.state_events = new_state
defer.returnValue(hasattr(event, "state_key"))
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
events = yield self.store.get_latest_events_in_room(room_id)
event_ids = [
e_id
for e_id, _, _ in events
]
current_state = snapshot.prev_state_pdu
res = yield self.resolve_state_groups(event_ids)
if current_state:
event.prev_state = encode_event_id(
current_state.pdu_id, current_state.origin
)
# TODO check current_state to see if the min power level is less
# than the power level of the user
# power_level = self._get_power_level_for_event(event)
pdu_id, origin = decode_event_id(event.event_id, self.server_name)
yield self.store.update_current_state(
pdu_id=pdu_id,
origin=origin,
context=key.context,
pdu_type=key.type,
state_key=key.state_key
)
defer.returnValue(True)
@defer.inlineCallbacks
@log_function
def handle_new_state(self, new_pdu):
""" Apply conflict resolution to `new_pdu`.
This should be called on every new state pdu, regardless of whether or
not there is a conflict.
This function is safe against the race of it getting called with two
`PDU`s trying to update the same state.
"""
# This needs to be done in a transaction.
is_new = yield self._handle_new_state(new_pdu)
logger.debug("is_new: %s %s %s", is_new, new_pdu.pdu_id, new_pdu.origin)
if is_new:
yield self.store.update_current_state(
pdu_id=new_pdu.pdu_id,
origin=new_pdu.origin,
context=new_pdu.context,
pdu_type=new_pdu.pdu_type,
state_key=new_pdu.state_key
)
defer.returnValue(is_new)
def _get_power_level_for_event(self, event):
# return self._persistence.get_power_level_for_user(event.room_id,
# event.sender)
return event.power_level
@defer.inlineCallbacks
@log_function
def _handle_new_state(self, new_pdu):
tree, missing_branch = yield self.store.get_unresolved_state_tree(
new_pdu
)
new_branch, current_branch = tree
logger.debug(
"_handle_new_state new=%s, current=%s",
new_branch, current_branch
)
if missing_branch is not None:
# We're missing some PDUs. Fetch them.
# TODO (erikj): Limit this.
missing_prev = tree[missing_branch][-1]
pdu_id = missing_prev.prev_state_id
origin = missing_prev.prev_state_origin
is_missing = yield self.store.get_pdu(pdu_id, origin) is None
if not is_missing:
raise Exception("Conflict resolution failed")
yield self._replication.get_pdu(
destination=missing_prev.origin,
pdu_origin=origin,
pdu_id=pdu_id,
outlier=True
)
updated_current = yield self._handle_new_state(new_pdu)
defer.returnValue(updated_current)
if not current_branch:
# There is no current state
defer.returnValue(True)
if event_type:
defer.returnValue(res[1].get((event_type, state_key)))
return
n = new_branch[-1]
c = current_branch[-1]
defer.returnValue(res[1].values())
common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin
@defer.inlineCallbacks
@log_function
def resolve_state_groups(self, event_ids):
state_groups = yield self.store.get_state_groups(
event_ids
)
if common_ancestor:
# We found a common ancestor!
group_names = set(state_groups.keys())
if len(group_names) == 1:
name, state_list = state_groups.items().pop()
state = {
(e.type, e.state_key): e
for e in state_list
}
defer.returnValue((name, state))
if len(current_branch) == 1:
# This is a direct clobber so we can just...
defer.returnValue(True)
state = {}
for group, g_state in state_groups.items():
for s in g_state:
state.setdefault(
(s.type, s.state_key),
{}
)[s.event_id] = s
unconflicted_state = {
k: v.values()[0] for k, v in state.items()
if len(v.values()) == 1
}
conflicted_state = {
k: v.values()
for k, v in state.items()
if len(v.values()) > 1
}
try:
new_state = {}
new_state.update(unconflicted_state)
for key, events in conflicted_state.items():
new_state[key] = self._resolve_state_events(events)
except:
logger.exception("Failed to resolve state")
raise
defer.returnValue((None, new_state))
def _get_power_level_from_event_state(self, event, user_id):
if hasattr(event, "old_state_events") and event.old_state_events:
key = (RoomPowerLevelsEvent.TYPE, "", )
power_level_event = event.old_state_events.get(key)
level = None
if power_level_event:
level = power_level_event.content.get("users", {}).get(
user_id
)
if not level:
level = power_level_event.content.get("users_default", 0)
return level
else:
# We didn't find a common ancestor. This is probably fine.
pass
return 0
result = yield self._do_conflict_res(
new_branch, current_branch, common_ancestor
)
defer.returnValue(result)
@log_function
def _resolve_state_events(self, events):
curr_events = events
@defer.inlineCallbacks
def _do_conflict_res(self, new_branch, current_branch, common_ancestor):
conflict_res = [
self._do_power_level_conflict_res,
self._do_chain_length_conflict_res,
self._do_hash_conflict_res,
new_powers = [
self._get_power_level_from_event_state(e, e.user_id)
for e in curr_events
]
for algo in conflict_res:
new_res, curr_res = yield defer.maybeDeferred(
algo,
new_branch, current_branch, common_ancestor
)
new_powers = [
int(p) if p else 0 for p in new_powers
]
if new_res < curr_res:
defer.returnValue(False)
elif new_res > curr_res:
defer.returnValue(True)
max_power = max(new_powers)
raise Exception("Conflict resolution failed.")
curr_events = [
z[0] for z in zip(curr_events, new_powers)
if z[1] == max_power
]
@defer.inlineCallbacks
def _do_power_level_conflict_res(self, new_branch, current_branch,
common_ancestor):
new_powers_deferreds = []
for e in new_branch[:-1] if common_ancestor else new_branch:
if hasattr(e, "user_id"):
new_powers_deferreds.append(
self.store.get_power_level(e.context, e.user_id)
)
current_powers_deferreds = []
for e in current_branch[:-1] if common_ancestor else current_branch:
if hasattr(e, "user_id"):
current_powers_deferreds.append(
self.store.get_power_level(e.context, e.user_id)
)
new_powers = yield defer.gatherResults(
new_powers_deferreds,
consumeErrors=True
)
current_powers = yield defer.gatherResults(
current_powers_deferreds,
consumeErrors=True
)
max_power_new = max(new_powers)
max_power_current = max(current_powers)
defer.returnValue(
(max_power_new, max_power_current)
)
def _do_chain_length_conflict_res(self, new_branch, current_branch,
common_ancestor):
return (len(new_branch), len(current_branch))
def _do_hash_conflict_res(self, new_branch, current_branch,
common_ancestor):
new_str = "".join([p.pdu_id + p.origin for p in new_branch])
c_str = "".join([p.pdu_id + p.origin for p in current_branch])
if not curr_events:
raise RuntimeError("Max didn't get a max?")
elif len(curr_events) == 1:
return curr_events[0]
# TODO: For now, just choose the one with the largest event_id.
return (
hashlib.sha1(new_str).hexdigest(),
hashlib.sha1(c_str).hexdigest()
sorted(
curr_events,
key=lambda e: hashlib.sha1(
e.event_id + e.user_id + e.room_id + e.type
).hexdigest()
)[0]
)

View File

@ -16,14 +16,7 @@
from twisted.internet import defer
from synapse.api.events.room import (
RoomMemberEvent, RoomTopicEvent, FeedbackEvent,
# RoomConfigEvent,
RoomNameEvent,
RoomJoinRulesEvent,
RoomPowerLevelsEvent,
RoomAddStateLevelEvent,
RoomSendEventLevelEvent,
RoomOpsPowerLevelsEvent,
RoomMemberEvent, RoomTopicEvent, FeedbackEvent, RoomNameEvent,
RoomRedactionEvent,
)
@ -37,9 +30,17 @@ from .registration import RegistrationStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .stream import StreamStore
from .pdu import StatePduStore, PduStore, PdusTable
from .transactions import TransactionStore
from .keys import KeyStore
from .event_federation import EventFederationStore
from .state import StateStore
from .signatures import SignatureStore
from syutil.base64util import decode_base64
from synapse.crypto.event_signing import compute_event_reference_hash
import json
import logging
@ -51,7 +52,6 @@ logger = logging.getLogger(__name__)
SCHEMAS = [
"transactions",
"pdu",
"users",
"profiles",
"presence",
@ -59,6 +59,9 @@ SCHEMAS = [
"room_aliases",
"keys",
"redactions",
"state",
"event_edges",
"event_signatures",
]
@ -73,10 +76,12 @@ class _RollbackButIsFineException(Exception):
"""
pass
class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
PresenceStore, PduStore, StatePduStore, TransactionStore,
DirectoryStore, KeyStore):
PresenceStore, TransactionStore,
DirectoryStore, KeyStore, StateStore, SignatureStore,
EventFederationStore, ):
def __init__(self, hs):
super(DataStore, self).__init__(hs)
@ -88,8 +93,7 @@ class DataStore(RoomMemberStore, RoomStore,
@defer.inlineCallbacks
@log_function
def persist_event(self, event=None, backfilled=False, pdu=None,
is_new_state=True):
def persist_event(self, event, backfilled=False, is_new_state=True):
stream_ordering = None
if backfilled:
if not self.min_token_deferred.called:
@ -99,8 +103,8 @@ class DataStore(RoomMemberStore, RoomStore,
try:
yield self.runInteraction(
self._persist_pdu_event_txn,
pdu=pdu,
"persist_event",
self._persist_event_txn,
event=event,
backfilled=backfilled,
stream_ordering=stream_ordering,
@ -119,7 +123,8 @@ class DataStore(RoomMemberStore, RoomStore,
"type",
"room_id",
"content",
"unrecognized_keys"
"unrecognized_keys",
"depth",
],
allow_none=allow_none,
)
@ -130,42 +135,6 @@ class DataStore(RoomMemberStore, RoomStore,
event = self._parse_event_from_row(events_dict)
defer.returnValue(event)
def _persist_pdu_event_txn(self, txn, pdu=None, event=None,
backfilled=False, stream_ordering=None,
is_new_state=True):
if pdu is not None:
self._persist_event_pdu_txn(txn, pdu)
if event is not None:
return self._persist_event_txn(
txn, event, backfilled, stream_ordering,
is_new_state=is_new_state,
)
def _persist_event_pdu_txn(self, txn, pdu):
cols = dict(pdu.__dict__)
unrec_keys = dict(pdu.unrecognized_keys)
del cols["content"]
del cols["prev_pdus"]
cols["content_json"] = json.dumps(pdu.content)
unrec_keys.update({
k: v for k, v in cols.items()
if k not in PdusTable.fields
})
cols["unrecognized_keys"] = json.dumps(unrec_keys)
cols["ts"] = cols.pop("origin_server_ts")
logger.debug("Persisting: %s", repr(cols))
if pdu.is_state:
self._persist_state_txn(txn, pdu.prev_pdus, cols)
else:
self._persist_pdu_txn(txn, pdu.prev_pdus, cols)
self._update_min_depth_for_context_txn(txn, pdu.context, pdu.depth)
@log_function
def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None,
is_new_state=True):
@ -177,19 +146,13 @@ class DataStore(RoomMemberStore, RoomStore,
self._store_room_name_txn(txn, event)
elif event.type == RoomTopicEvent.TYPE:
self._store_room_topic_txn(txn, event)
elif event.type == RoomJoinRulesEvent.TYPE:
self._store_join_rule(txn, event)
elif event.type == RoomPowerLevelsEvent.TYPE:
self._store_power_levels(txn, event)
elif event.type == RoomAddStateLevelEvent.TYPE:
self._store_add_state_level(txn, event)
elif event.type == RoomSendEventLevelEvent.TYPE:
self._store_send_event_level(txn, event)
elif event.type == RoomOpsPowerLevelsEvent.TYPE:
self._store_ops_level(txn, event)
elif event.type == RoomRedactionEvent.TYPE:
self._store_redaction(txn, event)
outlier = False
if hasattr(event, "outlier"):
outlier = event.outlier
vals = {
"topological_ordering": event.depth,
"event_id": event.event_id,
@ -197,25 +160,33 @@ class DataStore(RoomMemberStore, RoomStore,
"room_id": event.room_id,
"content": json.dumps(event.content),
"processed": True,
"outlier": outlier,
"depth": event.depth,
}
if stream_ordering is not None:
vals["stream_ordering"] = stream_ordering
if hasattr(event, "outlier"):
vals["outlier"] = event.outlier
else:
vals["outlier"] = False
unrec = {
k: v
for k, v in event.get_full_dict().items()
if k not in vals.keys() and k not in ["redacted", "redacted_because"]
if k not in vals.keys() and k not in [
"redacted",
"redacted_because",
"signatures",
"hashes",
"prev_events",
]
}
vals["unrecognized_keys"] = json.dumps(unrec)
try:
self._simple_insert_txn(txn, "events", vals)
self._simple_insert_txn(
txn,
"events",
vals,
or_replace=(not outlier),
)
except:
logger.warn(
"Failed to persist, probably duplicate: %s",
@ -224,6 +195,16 @@ class DataStore(RoomMemberStore, RoomStore,
)
raise _RollbackButIsFineException("_persist_event")
self._handle_prev_events(
txn,
outlier=outlier,
event_id=event.event_id,
prev_events=event.prev_events,
room_id=event.room_id,
)
self._store_state_groups_txn(txn, event)
is_state = hasattr(event, "state_key") and event.state_key is not None
if is_new_state and is_state:
vals = {
@ -233,8 +214,8 @@ class DataStore(RoomMemberStore, RoomStore,
"state_key": event.state_key,
}
if hasattr(event, "prev_state"):
vals["prev_state"] = event.prev_state
if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state
self._simple_insert_txn(txn, "state_events", vals)
@ -249,6 +230,81 @@ class DataStore(RoomMemberStore, RoomStore,
}
)
for e_id, h in event.prev_state:
self._simple_insert_txn(
txn,
table="event_edges",
values={
"event_id": event.event_id,
"prev_event_id": e_id,
"room_id": event.room_id,
"is_state": 1,
},
or_ignore=True,
)
if not backfilled:
self._simple_insert_txn(
txn,
table="state_forward_extremities",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
}
)
for prev_state_id, _ in event.prev_state:
self._simple_delete_txn(
txn,
table="state_forward_extremities",
keyvalues={
"event_id": prev_state_id,
}
)
for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_event_content_hash_txn(
txn, event.event_id, hash_alg, hash_bytes,
)
if hasattr(event, "signatures"):
signatures = event.signatures.get(event.origin, {})
for key_id, signature_base64 in signatures.items():
signature_bytes = decode_base64(signature_base64)
self._store_event_origin_signature_txn(
txn, event.event_id, event.origin, key_id, signature_bytes,
)
for prev_event_id, prev_hashes in event.prev_events:
for alg, hash_base64 in prev_hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_prev_event_hash_txn(
txn, event.event_id, prev_event_id, alg, hash_bytes
)
for auth_id, _ in event.auth_events:
self._simple_insert_txn(
txn,
table="event_auth",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"auth_id": auth_id,
},
or_ignore=True,
)
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
self._store_event_reference_hash_txn(
txn, event.event_id, ref_alg, ref_hash_bytes
)
self._update_min_depth_for_room_txn(txn, event.room_id, event.depth)
def _store_redaction(self, txn, event):
txn.execute(
"INSERT OR IGNORE INTO redactions "
@ -319,7 +375,7 @@ class DataStore(RoomMemberStore, RoomStore,
],
)
def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
def snapshot_room(self, event):
"""Snapshot the room for an update by a user
Args:
room_id (synapse.types.RoomId): The room to snapshot.
@ -330,29 +386,33 @@ class DataStore(RoomMemberStore, RoomStore,
synapse.storage.Snapshot: A snapshot of the state of the room.
"""
def _snapshot(txn):
membership_state = self._get_room_member(txn, user_id, room_id)
prev_pdus = self._get_latest_pdus_in_context(
txn, room_id
prev_events = self._get_latest_events_in_room(
txn,
event.room_id
)
if state_type is not None and state_key is not None:
prev_state_pdu = self._get_current_state_pdu(
txn, room_id, state_type, state_key
prev_state = None
state_key = None
if hasattr(event, "state_key"):
state_key = event.state_key
prev_state = self._get_latest_state_in_room(
txn,
event.room_id,
type=event.type,
state_key=state_key,
)
else:
prev_state_pdu = None
return Snapshot(
store=self,
room_id=room_id,
user_id=user_id,
prev_pdus=prev_pdus,
membership_state=membership_state,
state_type=state_type,
room_id=event.room_id,
user_id=event.user_id,
prev_events=prev_events,
prev_state=prev_state,
state_type=event.type,
state_key=state_key,
prev_state_pdu=prev_state_pdu,
)
return self.runInteraction(_snapshot)
return self.runInteraction("snapshot_room", _snapshot)
class Snapshot(object):
@ -361,7 +421,7 @@ class Snapshot(object):
store (DataStore): The datastore.
room_id (RoomId): The room of the snapshot.
user_id (UserId): The user this snapshot is for.
prev_pdus (list): The list of PDU ids this snapshot is after.
prev_events (list): The list of event ids this snapshot is after.
membership_state (RoomMemberEvent): The current state of the user in
the room.
state_type (str, optional): State type captured by the snapshot
@ -370,32 +430,30 @@ class Snapshot(object):
the previous value of the state type and key in the room.
"""
def __init__(self, store, room_id, user_id, prev_pdus,
membership_state, state_type=None, state_key=None,
prev_state_pdu=None):
def __init__(self, store, room_id, user_id, prev_events,
prev_state, state_type=None, state_key=None):
self.store = store
self.room_id = room_id
self.user_id = user_id
self.prev_pdus = prev_pdus
self.membership_state = membership_state
self.prev_events = prev_events
self.prev_state = prev_state
self.state_type = state_type
self.state_key = state_key
self.prev_state_pdu = prev_state_pdu
def fill_out_prev_events(self, event):
if hasattr(event, "prev_events"):
return
if not hasattr(event, "prev_events"):
event.prev_events = [
(event_id, hashes)
for event_id, hashes, _ in self.prev_events
]
es = [
"%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus
]
if self.prev_events:
event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
else:
event.depth = 0
event.prev_events = [e for e in es if e != event.event_id]
if self.prev_pdus:
event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1
else:
event.depth = 0
if not hasattr(event, "prev_state") and self.prev_state is not None:
event.prev_state = self.prev_state
def schema_path(schema):
@ -436,11 +494,13 @@ def prepare_database(db_conn):
user_version = row[0]
if user_version > SCHEMA_VERSION:
raise ValueError("Cannot use this database as it is too " +
raise ValueError(
"Cannot use this database as it is too " +
"new for the server to understand"
)
elif user_version < SCHEMA_VERSION:
logging.info("Upgrading database from version %d",
logging.info(
"Upgrading database from version %d",
user_version
)
@ -452,13 +512,13 @@ def prepare_database(db_conn):
db_conn.commit()
else:
sql_script = "BEGIN TRANSACTION;"
sql_script = "BEGIN TRANSACTION;\n"
for sql_loc in SCHEMAS:
sql_script += read_schema(sql_loc)
sql_script += "\n"
sql_script += "COMMIT TRANSACTION;"
c.executescript(sql_script)
db_conn.commit()
c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
c.close()

View File

@ -14,59 +14,69 @@
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.api.events.utils import prune_event
from synapse.util.logutils import log_function
from syutil.base64util import encode_base64
import collections
import copy
import json
import sys
import time
logger = logging.getLogger(__name__)
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging to the .execute() method."""
__slots__ = ["txn"]
__slots__ = ["txn", "name"]
def __init__(self, txn):
def __init__(self, txn, name):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
def __getattribute__(self, name):
if name == "execute":
return object.__getattribute__(self, "execute")
return getattr(object.__getattribute__(self, "txn"), name)
def __getattr__(self, name):
return getattr(self.txn, name)
def __setattr__(self, name, value):
setattr(object.__getattribute__(self, "txn"), name, value)
setattr(self.txn, name, value)
def execute(self, sql, *args, **kwargs):
# TODO(paul): Maybe use 'info' and 'debug' for values?
sql_logger.debug("[SQL] %s", sql)
sql_logger.debug("[SQL] {%s} %s", self.name, sql)
try:
if args and args[0]:
values = args[0]
sql_logger.debug("[SQL values] " +
", ".join(("<%s>",) * len(values)), *values)
sql_logger.debug(
"[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)),
self.name,
*values
)
except:
# Don't let logging failures stop SQL from working
pass
# TODO(paul): Here would be an excellent place to put some timing
# measurements, and log (warning?) slow queries.
return object.__getattribute__(self, "txn").execute(
sql, *args, **kwargs
)
start = time.clock() * 1000
try:
return self.txn.execute(
sql, *args, **kwargs
)
except:
logger.exception("[SQL FAIL] {%s}", self.name)
raise
finally:
end = time.clock() * 1000
sql_logger.debug("[SQL time] {%s} %f", self.name, end - start)
class SQLBaseStore(object):
_TXN_ID = 0
def __init__(self, hs):
self.hs = hs
@ -74,10 +84,30 @@ class SQLBaseStore(object):
self.event_factory = hs.get_event_factory()
self._clock = hs.get_clock()
def runInteraction(self, func, *args, **kwargs):
def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
def inner_func(txn, *args, **kwargs):
return func(LoggingTransaction(txn), *args, **kwargs)
start = time.clock() * 1000
txn_id = SQLBaseStore._TXN_ID
# We don't really need these to be unique, so lets stop it from
# growing really large.
self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
name = "%s-%x" % (desc, txn_id, )
transaction_logger.debug("[TXN START] {%s}", name)
try:
return func(LoggingTransaction(txn, name), *args, **kwargs)
except:
logger.exception("[TXN FAIL] {%s}", name)
raise
finally:
end = time.clock() * 1000
transaction_logger.debug(
"[TXN END] {%s} %f",
name, end - start
)
return self._db_pool.runInteraction(inner_func, *args, **kwargs)
@ -113,7 +143,7 @@ class SQLBaseStore(object):
else:
return cursor.fetchall()
return self.runInteraction(interaction)
return self.runInteraction("_execute", interaction)
def _execute_and_decode(self, query, *args):
return self._execute(self.cursor_to_dict, query, *args)
@ -130,6 +160,7 @@ class SQLBaseStore(object):
or_replace : bool; if True performs an INSERT OR REPLACE
"""
return self.runInteraction(
"_simple_insert",
self._simple_insert_txn, table, values, or_replace=or_replace,
or_ignore=or_ignore,
)
@ -170,7 +201,6 @@ class SQLBaseStore(object):
table, keyvalues, retcols=retcols, allow_none=allow_none
)
@defer.inlineCallbacks
def _simple_select_one_onecol(self, table, keyvalues, retcol,
allow_none=False):
"""Executes a SELECT query on the named table, which is expected to
@ -181,19 +211,40 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the row with
retcol : string giving the name of the column to return
"""
ret = yield self._simple_select_one(
return self.runInteraction(
"_simple_select_one_onecol",
self._simple_select_one_onecol_txn,
table, keyvalues, retcol, allow_none=allow_none,
)
def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
allow_none=False):
ret = self._simple_select_onecol_txn(
txn,
table=table,
keyvalues=keyvalues,
retcols=[retcol],
allow_none=allow_none
retcol=retcol,
)
if ret:
defer.returnValue(ret[retcol])
return ret[0]
else:
defer.returnValue(None)
if allow_none:
return None
else:
raise StoreError(404, "No row found")
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
"retcol": retcol,
"table": table,
"where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
}
txn.execute(sql, keyvalues.values())
return [r[0] for r in txn.fetchall()]
@defer.inlineCallbacks
def _simple_select_onecol(self, table, keyvalues, retcol):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
@ -206,19 +257,11 @@ class SQLBaseStore(object):
Returns:
Deferred: Results in a list
"""
sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
"retcol": retcol,
"table": table,
"where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
}
def func(txn):
txn.execute(sql, keyvalues.values())
return txn.fetchall()
res = yield self.runInteraction(func)
defer.returnValue([r[0] for r in res])
return self.runInteraction(
"_simple_select_onecol",
self._simple_select_onecol_txn,
table, keyvalues, retcol
)
def _simple_select_list(self, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
@ -229,17 +272,30 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
"""
return self.runInteraction(
"_simple_select_list",
self._simple_select_list_txn,
table, keyvalues, retcols
)
def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
txn : Transaction object
table : string giving the table name
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
"""
sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k) for k in keyvalues)
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
def func(txn):
txn.execute(sql, keyvalues.values())
return self.cursor_to_dict(txn)
return self.runInteraction(func)
txn.execute(sql, keyvalues.values())
return self.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues,
retcols=None):
@ -307,7 +363,7 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched")
return ret
return self.runInteraction(func)
return self.runInteraction("_simple_selectupdate_one", func)
def _simple_delete_one(self, table, keyvalues):
"""Executes a DELETE query on the named table, expecting to delete a
@ -319,7 +375,7 @@ class SQLBaseStore(object):
"""
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k) for k in keyvalues)
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
def func(txn):
@ -328,7 +384,25 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
return self.runInteraction(func)
return self.runInteraction("_simple_delete_one", func)
def _simple_delete(self, table, keyvalues):
"""Executes a DELETE query on the named table.
Args:
table : string giving the table name
keyvalues : dict of column names and values to select the row with
"""
return self.runInteraction("_simple_delete", self._simple_delete_txn)
def _simple_delete_txn(self, txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
return txn.execute(sql, keyvalues.values())
def _simple_max_id(self, table):
"""Executes a SELECT query on the named table, expecting to return the
@ -346,7 +420,7 @@ class SQLBaseStore(object):
return 0
return max_id
return self.runInteraction(func)
return self.runInteraction("_simple_max_id", func)
def _parse_event_from_row(self, row_dict):
d = copy.deepcopy({k: v for k, v in row_dict.items()})
@ -355,6 +429,10 @@ class SQLBaseStore(object):
d.pop("topological_ordering", None)
d.pop("processed", None)
d["origin_server_ts"] = d.pop("ts", 0)
replaces_state = d.pop("prev_state", None)
if replaces_state:
d["replaces_state"] = replaces_state
d.update(json.loads(row_dict["unrecognized_keys"]))
d["content"] = json.loads(d["content"])
@ -369,23 +447,65 @@ class SQLBaseStore(object):
**d
)
def _get_events_txn(self, txn, event_ids):
# FIXME (erikj): This should be batched?
sql = "SELECT * FROM events WHERE event_id = ?"
event_rows = []
for e_id in event_ids:
c = txn.execute(sql, (e_id,))
event_rows.extend(self.cursor_to_dict(c))
return self._parse_events_txn(txn, event_rows)
def _parse_events(self, rows):
return self.runInteraction(self._parse_events_txn, rows)
return self.runInteraction(
"_parse_events", self._parse_events_txn, rows
)
def _parse_events_txn(self, txn, rows):
events = [self._parse_event_from_row(r) for r in rows]
sql = "SELECT * FROM events WHERE event_id = ?"
select_event_sql = "SELECT * FROM events WHERE event_id = ?"
for ev in events:
if hasattr(ev, "prev_state"):
# Load previous state_content.
# TODO: Should we be pulling this out above?
cursor = txn.execute(sql, (ev.prev_state,))
prevs = self.cursor_to_dict(cursor)
if prevs:
prev = self._parse_event_from_row(prevs[0])
ev.prev_content = prev.content
for i, ev in enumerate(events):
signatures = self._get_event_origin_signatures_txn(
txn, ev.event_id,
)
ev.signatures = {
k: encode_base64(v) for k, v in signatures.items()
}
prevs = self._get_prev_events_and_state(txn, ev.event_id)
ev.prev_events = [
(e_id, h)
for e_id, h, is_state in prevs
if is_state == 0
]
ev.auth_events = self._get_auth_events(txn, ev.event_id)
if hasattr(ev, "state_key"):
ev.prev_state = [
(e_id, h)
for e_id, h, is_state in prevs
if is_state == 1
]
if hasattr(ev, "replaces_state"):
# Load previous state_content.
# FIXME (erikj): Handle multiple prev_states.
cursor = txn.execute(
select_event_sql,
(ev.replaces_state,)
)
prevs = self.cursor_to_dict(cursor)
if prevs:
prev = self._parse_event_from_row(prevs[0])
ev.prev_content = prev.content
if not hasattr(ev, "redacted"):
logger.debug("Doesn't have redacted key: %s", ev)
@ -393,15 +513,16 @@ class SQLBaseStore(object):
if ev.redacted:
# Get the redaction event.
sql = "SELECT * FROM events WHERE event_id = ?"
txn.execute(sql, (ev.redacted,))
select_event_sql = "SELECT * FROM events WHERE event_id = ?"
txn.execute(select_event_sql, (ev.redacted,))
del_evs = self._parse_events_txn(
txn, self.cursor_to_dict(txn)
)
if del_evs:
prune_event(ev)
ev = prune_event(ev)
events[i] = ev
ev.redacted_because = del_evs[0]
return events

View File

@ -95,6 +95,7 @@ class DirectoryStore(SQLBaseStore):
def delete_room_alias(self, room_alias):
return self.runInteraction(
"delete_room_alias",
self._delete_room_alias_txn,
room_alias,
)

View File

@ -0,0 +1,377 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from syutil.base64util import encode_base64
import logging
logger = logging.getLogger(__name__)
class EventFederationStore(SQLBaseStore):
def get_auth_chain(self, event_id):
return self.runInteraction(
"get_auth_chain",
self._get_auth_chain_txn,
event_id
)
def _get_auth_chain_txn(self, txn, event_id):
results = self._get_auth_chain_ids_txn(txn, event_id)
sql = "SELECT * FROM events WHERE event_id = ?"
rows = []
for ev_id in results:
c = txn.execute(sql, (ev_id,))
rows.extend(self.cursor_to_dict(c))
return self._parse_events_txn(txn, rows)
def get_auth_chain_ids(self, event_id):
return self.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
event_id
)
def _get_auth_chain_ids_txn(self, txn, event_id):
results = set()
base_sql = (
"SELECT auth_id FROM event_auth WHERE %s"
)
front = set([event_id])
while front:
sql = base_sql % (
" OR ".join(["event_id=?"] * len(front)),
)
txn.execute(sql, list(front))
front = [r[0] for r in txn.fetchall()]
results.update(front)
return list(results)
def get_oldest_events_in_room(self, room_id):
return self.runInteraction(
"get_oldest_events_in_room",
self._get_oldest_events_in_room_txn,
room_id,
)
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn(
txn,
table="event_backward_extremities",
keyvalues={
"room_id": room_id,
},
retcol="event_id",
)
def get_latest_events_in_room(self, room_id):
return self.runInteraction(
"get_latest_events_in_room",
self._get_latest_events_in_room,
room_id,
)
def _get_latest_events_in_room(self, txn, room_id):
sql = (
"SELECT e.event_id, e.depth FROM events as e "
"INNER JOIN event_forward_extremities as f "
"ON e.event_id = f.event_id "
"WHERE f.room_id = ?"
)
txn.execute(sql, (room_id, ))
results = []
for event_id, depth in txn.fetchall():
hashes = self._get_event_reference_hashes_txn(txn, event_id)
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items()
if k == "sha256"
}
results.append((event_id, prev_hashes, depth))
return results
def _get_latest_state_in_room(self, txn, room_id, type, state_key):
event_ids = self._simple_select_onecol_txn(
txn,
table="state_forward_extremities",
keyvalues={
"room_id": room_id,
"type": type,
"state_key": state_key,
},
retcol="event_id",
)
results = []
for event_id in event_ids:
hashes = self._get_event_reference_hashes_txn(txn, event_id)
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items()
if k == "sha256"
}
results.append((event_id, prev_hashes))
return results
def _get_prev_events(self, txn, event_id):
results = self._get_prev_events_and_state(
txn,
event_id,
is_state=0,
)
return [(e_id, h, ) for e_id, h, _ in results]
def _get_prev_state(self, txn, event_id):
results = self._get_prev_events_and_state(
txn,
event_id,
is_state=1,
)
return [(e_id, h, ) for e_id, h, _ in results]
def _get_prev_events_and_state(self, txn, event_id, is_state=None):
keyvalues = {
"event_id": event_id,
}
if is_state is not None:
keyvalues["is_state"] = is_state
res = self._simple_select_list_txn(
txn,
table="event_edges",
keyvalues=keyvalues,
retcols=["prev_event_id", "is_state"],
)
results = []
for d in res:
hashes = self._get_event_reference_hashes_txn(
txn,
d["prev_event_id"]
)
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items()
if k == "sha256"
}
results.append((d["prev_event_id"], prev_hashes, d["is_state"]))
return results
def _get_auth_events(self, txn, event_id):
auth_ids = self._simple_select_onecol_txn(
txn,
table="event_auth",
keyvalues={
"event_id": event_id,
},
retcol="auth_id",
)
results = []
for auth_id in auth_ids:
hashes = self._get_event_reference_hashes_txn(txn, auth_id)
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items()
if k == "sha256"
}
results.append((auth_id, prev_hashes))
return results
def get_min_depth(self, room_id):
return self.runInteraction(
"get_min_depth",
self._get_min_depth_interaction,
room_id,
)
def _get_min_depth_interaction(self, txn, room_id):
min_depth = self._simple_select_one_onecol_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
retcol="min_depth",
allow_none=True,
)
return int(min_depth) if min_depth is not None else None
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
min_depth = self._get_min_depth_interaction(txn, room_id)
do_insert = depth < min_depth if min_depth else True
if do_insert:
self._simple_insert_txn(
txn,
table="room_depth",
values={
"room_id": room_id,
"min_depth": depth,
},
or_replace=True,
)
def _handle_prev_events(self, txn, outlier, event_id, prev_events,
room_id):
for e_id, _ in prev_events:
# TODO (erikj): This could be done as a bulk insert
self._simple_insert_txn(
txn,
table="event_edges",
values={
"event_id": event_id,
"prev_event_id": e_id,
"room_id": room_id,
"is_state": 0,
},
or_ignore=True,
)
# Update the extremities table if this is not an outlier.
if not outlier:
for e_id, _ in prev_events:
# TODO (erikj): This could be done as a bulk insert
self._simple_delete_txn(
txn,
table="event_forward_extremities",
keyvalues={
"event_id": e_id,
"room_id": room_id,
}
)
# We only insert as a forward extremity the new pdu if there are
# no other pdus that reference it as a prev pdu
query = (
"INSERT OR IGNORE INTO %(table)s (event_id, room_id) "
"SELECT ?, ? WHERE NOT EXISTS ("
"SELECT 1 FROM %(event_edges)s WHERE "
"prev_event_id = ? "
")"
) % {
"table": "event_forward_extremities",
"event_edges": "event_edges",
}
logger.debug("query: %s", query)
txn.execute(query, (event_id, room_id, event_id))
# Insert all the prev_pdus as a backwards thing, they'll get
# deleted in a second if they're incorrect anyway.
for e_id, _ in prev_events:
# TODO (erikj): This could be done as a bulk insert
self._simple_insert_txn(
txn,
table="event_backward_extremities",
values={
"event_id": e_id,
"room_id": room_id,
},
or_ignore=True,
)
# Also delete from the backwards extremities table all ones that
# reference pdus that we have already seen
query = (
"DELETE FROM event_backward_extremities WHERE EXISTS ("
"SELECT 1 FROM events "
"WHERE "
"event_backward_extremities.event_id = events.event_id "
"AND not events.outlier "
")"
)
txn.execute(query)
def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occured before (and
including) the pdus in pdu_list. Return a list of max size `limit`.
Args:
txn
room_id (str)
event_list (list)
limit (int)
Return:
list: A list of PduTuples
"""
return self.runInteraction(
"get_backfill_events",
self._get_backfill_events, room_id, event_list, limit
)
def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug(
"_get_backfill_events: %s, %s, %s",
room_id, repr(event_list), limit
)
# We seed the pdu_results with the things from the pdu_list.
event_results = event_list
front = event_list
query = (
"SELECT prev_event_id FROM event_edges "
"WHERE room_id = ? AND event_id = ? "
"LIMIT ?"
)
# We iterate through all event_ids in `front` to select their previous
# events. These are dumped in `new_front`.
# We continue until we reach the limit *or* new_front is empty (i.e.,
# we've run out of things to select
while front and len(event_results) < limit:
new_front = []
for event_id in front:
logger.debug(
"_backfill_interaction: id=%s",
event_id
)
txn.execute(
query,
(room_id, event_id, limit - len(event_results))
)
for row in txn.fetchall():
logger.debug(
"_backfill_interaction: got id=%s",
*row
)
new_front.append(row[0])
front = new_front
event_results += new_front
# We also want to update the `prev_pdus` attributes before returning.
return self._get_events_txn(txn, event_results)

View File

@ -1,915 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import SQLBaseStore, Table, JoinHelper
from synapse.federation.units import Pdu
from synapse.util.logutils import log_function
from collections import namedtuple
import logging
logger = logging.getLogger(__name__)
class PduStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
"""
def get_pdu(self, pdu_id, origin):
"""Given a pdu_id and origin, get a PDU.
Args:
txn
pdu_id (str)
origin (str)
Returns:
PduTuple: If the pdu does not exist in the database, returns None
"""
return self.runInteraction(
self._get_pdu_tuple, pdu_id, origin
)
def _get_pdu_tuple(self, txn, pdu_id, origin):
res = self._get_pdu_tuples(txn, [(pdu_id, origin)])
return res[0] if res else None
def _get_pdu_tuples(self, txn, pdu_id_tuples):
results = []
for pdu_id, origin in pdu_id_tuples:
txn.execute(
PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"),
(pdu_id, origin)
)
edges = [
(r.prev_pdu_id, r.prev_origin)
for r in PduEdgesTable.decode_results(txn.fetchall())
]
query = (
"SELECT %(fields)s FROM %(pdus)s as p "
"LEFT JOIN %(state)s as s "
"ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
"WHERE p.pdu_id = ? AND p.origin = ? "
) % {
"fields": _pdu_state_joiner.get_fields(
PdusTable="p", StatePdusTable="s"),
"pdus": PdusTable.table_name,
"state": StatePdusTable.table_name,
}
txn.execute(query, (pdu_id, origin))
row = txn.fetchone()
if row:
results.append(PduTuple(PduEntry(*row), edges))
return results
def get_current_state_for_context(self, context):
"""Get a list of PDUs that represent the current state for a given
context
Args:
context (str)
Returns:
list: A list of PduTuples
"""
return self.runInteraction(
self._get_current_state_for_context,
context
)
def _get_current_state_for_context(self, txn, context):
query = (
"SELECT pdu_id, origin FROM %s WHERE context = ?"
% CurrentStateTable.table_name
)
logger.debug("get_current_state %s, Args=%s", query, context)
txn.execute(query, (context,))
res = txn.fetchall()
logger.debug("get_current_state %d results", len(res))
return self._get_pdu_tuples(txn, res)
def _persist_pdu_txn(self, txn, prev_pdus, cols):
"""Inserts a (non-state) PDU into the database.
Args:
txn,
prev_pdus (list)
**cols: The columns to insert into the PdusTable.
"""
entry = PdusTable.EntryType(
**{k: cols.get(k, None) for k in PdusTable.fields}
)
txn.execute(PdusTable.insert_statement(), entry)
self._handle_prev_pdus(
txn, entry.outlier, entry.pdu_id, entry.origin,
prev_pdus, entry.context
)
def mark_pdu_as_processed(self, pdu_id, pdu_origin):
"""Mark a received PDU as processed.
Args:
txn
pdu_id (str)
pdu_origin (str)
"""
return self.runInteraction(
self._mark_as_processed, pdu_id, pdu_origin
)
def _mark_as_processed(self, txn, pdu_id, pdu_origin):
txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name)
def get_all_pdus_from_context(self, context):
"""Get a list of all PDUs for a given context."""
return self.runInteraction(
self._get_all_pdus_from_context, context,
)
def _get_all_pdus_from_context(self, txn, context):
query = (
"SELECT pdu_id, origin FROM %s "
"WHERE context = ?"
) % PdusTable.table_name
txn.execute(query, (context,))
return self._get_pdu_tuples(txn, txn.fetchall())
def get_backfill(self, context, pdu_list, limit):
"""Get a list of Pdus for a given topic that occured before (and
including) the pdus in pdu_list. Return a list of max size `limit`.
Args:
txn
context (str)
pdu_list (list)
limit (int)
Return:
list: A list of PduTuples
"""
return self.runInteraction(
self._get_backfill, context, pdu_list, limit
)
def _get_backfill(self, txn, context, pdu_list, limit):
logger.debug(
"backfill: %s, %s, %s",
context, repr(pdu_list), limit
)
# We seed the pdu_results with the things from the pdu_list.
pdu_results = pdu_list
front = pdu_list
query = (
"SELECT prev_pdu_id, prev_origin FROM %(edges_table)s "
"WHERE context = ? AND pdu_id = ? AND origin = ? "
"LIMIT ?"
) % {
"edges_table": PduEdgesTable.table_name,
}
# We iterate through all pdu_ids in `front` to select their previous
# pdus. These are dumped in `new_front`. We continue until we reach the
# limit *or* new_front is empty (i.e., we've run out of things to
# select
while front and len(pdu_results) < limit:
new_front = []
for pdu_id, origin in front:
logger.debug(
"_backfill_interaction: i=%s, o=%s",
pdu_id, origin
)
txn.execute(
query,
(context, pdu_id, origin, limit - len(pdu_results))
)
for row in txn.fetchall():
logger.debug(
"_backfill_interaction: got i=%s, o=%s",
*row
)
new_front.append(row)
front = new_front
pdu_results += new_front
# We also want to update the `prev_pdus` attributes before returning.
return self._get_pdu_tuples(txn, pdu_results)
def get_min_depth_for_context(self, context):
"""Get the current minimum depth for a context
Args:
txn
context (str)
"""
return self.runInteraction(
self._get_min_depth_for_context, context
)
def _get_min_depth_for_context(self, txn, context):
return self._get_min_depth_interaction(txn, context)
def _get_min_depth_interaction(self, txn, context):
txn.execute(
"SELECT min_depth FROM %s WHERE context = ?"
% ContextDepthTable.table_name,
(context,)
)
row = txn.fetchone()
return row[0] if row else None
def _update_min_depth_for_context_txn(self, txn, context, depth):
"""Update the minimum `depth` of the given context, which is the line
on which we stop backfilling backwards.
Args:
context (str)
depth (int)
"""
min_depth = self._get_min_depth_interaction(txn, context)
do_insert = depth < min_depth if min_depth else True
if do_insert:
txn.execute(
"INSERT OR REPLACE INTO %s (context, min_depth) "
"VALUES (?,?)" % ContextDepthTable.table_name,
(context, depth)
)
def _get_latest_pdus_in_context(self, txn, context):
"""Get's a list of the most current pdus for a given context. This is
used when we are sending a Pdu and need to fill out the `prev_pdus`
key
Args:
txn
context
"""
query = (
"SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p "
"INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id "
"AND f.origin = p.origin "
"WHERE f.context = ?"
) % {
"pdus": PdusTable.table_name,
"forward": PduForwardExtremitiesTable.table_name,
}
logger.debug("get_prev query: %s", query)
txn.execute(
query,
(context, )
)
results = txn.fetchall()
return [(row[0], row[1], row[2]) for row in results]
@defer.inlineCallbacks
def get_oldest_pdus_in_context(self, context):
"""Get a list of Pdus that we haven't backfilled beyond yet (and havent
seen). This list is used when we want to backfill backwards and is the
list we send to the remote server.
Args:
txn
context (str)
Returns:
list: A list of PduIdTuple.
"""
results = yield self._execute(
None,
"SELECT pdu_id, origin FROM %(back)s WHERE context = ?"
% {"back": PduBackwardExtremitiesTable.table_name, },
context
)
defer.returnValue([PduIdTuple(i, o) for i, o in results])
def is_pdu_new(self, pdu_id, origin, context, depth):
"""For a given Pdu, try and figure out if it's 'new', i.e., if it's
not something we got randomly from the past, for example when we
request the current state of the room that will probably return a bunch
of pdus from before we joined.
Args:
txn
pdu_id (str)
origin (str)
context (str)
depth (int)
Returns:
bool
"""
return self.runInteraction(
self._is_pdu_new,
pdu_id=pdu_id,
origin=origin,
context=context,
depth=depth
)
def _is_pdu_new(self, txn, pdu_id, origin, context, depth):
# If depth > min depth in back table, then we classify it as new.
# OR if there is nothing in the back table, then it kinda needs to
# be a new thing.
query = (
"SELECT min(p.depth) FROM %(edges)s as e "
"INNER JOIN %(back)s as b "
"ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin "
"INNER JOIN %(pdus)s as p "
"ON e.pdu_id = p.pdu_id AND p.origin = e.origin "
"WHERE p.context = ?"
) % {
"pdus": PdusTable.table_name,
"edges": PduEdgesTable.table_name,
"back": PduBackwardExtremitiesTable.table_name,
}
txn.execute(query, (context,))
min_depth, = txn.fetchone()
if not min_depth or depth > int(min_depth):
logger.debug(
"is_new true: id=%s, o=%s, d=%s min_depth=%s",
pdu_id, origin, depth, min_depth
)
return True
# If this pdu is in the forwards table, then it also is a new one
query = (
"SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?"
) % {
"forward": PduForwardExtremitiesTable.table_name,
}
txn.execute(query, (pdu_id, origin))
# Did we get anything?
if txn.fetchall():
logger.debug(
"is_new true: id=%s, o=%s, d=%s was forward",
pdu_id, origin, depth
)
return True
logger.debug(
"is_new false: id=%s, o=%s, d=%s",
pdu_id, origin, depth
)
# FINE THEN. It's probably old.
return False
@staticmethod
@log_function
def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus,
context):
txn.executemany(
PduEdgesTable.insert_statement(),
[(pdu_id, origin, p[0], p[1], context) for p in prev_pdus]
)
# Update the extremities table if this is not an outlier.
if not outlier:
# First, we delete the new one from the forwards extremities table.
query = (
"DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
% PduForwardExtremitiesTable.table_name
)
txn.executemany(query, prev_pdus)
# We only insert as a forward extremety the new pdu if there are no
# other pdus that reference it as a prev pdu
query = (
"INSERT INTO %(table)s (pdu_id, origin, context) "
"SELECT ?, ?, ? WHERE NOT EXISTS ("
"SELECT 1 FROM %(pdu_edges)s WHERE "
"prev_pdu_id = ? AND prev_origin = ?"
")"
) % {
"table": PduForwardExtremitiesTable.table_name,
"pdu_edges": PduEdgesTable.table_name
}
logger.debug("query: %s", query)
txn.execute(query, (pdu_id, origin, context, pdu_id, origin))
# Insert all the prev_pdus as a backwards thing, they'll get
# deleted in a second if they're incorrect anyway.
txn.executemany(
PduBackwardExtremitiesTable.insert_statement(),
[(i, o, context) for i, o in prev_pdus]
)
# Also delete from the backwards extremities table all ones that
# reference pdus that we have already seen
query = (
"DELETE FROM %(pdu_back)s WHERE EXISTS ("
"SELECT 1 FROM %(pdus)s AS pdus "
"WHERE "
"%(pdu_back)s.pdu_id = pdus.pdu_id "
"AND %(pdu_back)s.origin = pdus.origin "
"AND not pdus.outlier "
")"
) % {
"pdu_back": PduBackwardExtremitiesTable.table_name,
"pdus": PdusTable.table_name,
}
txn.execute(query)
class StatePduStore(SQLBaseStore):
"""A collection of queries for handling state PDUs.
"""
def _persist_state_txn(self, txn, prev_pdus, cols):
"""Inserts a state PDU into the database
Args:
txn,
prev_pdus (list)
**cols: The columns to insert into the PdusTable and StatePdusTable
"""
pdu_entry = PdusTable.EntryType(
**{k: cols.get(k, None) for k in PdusTable.fields}
)
state_entry = StatePdusTable.EntryType(
**{k: cols.get(k, None) for k in StatePdusTable.fields}
)
logger.debug("Inserting pdu: %s", repr(pdu_entry))
logger.debug("Inserting state: %s", repr(state_entry))
txn.execute(PdusTable.insert_statement(), pdu_entry)
txn.execute(StatePdusTable.insert_statement(), state_entry)
self._handle_prev_pdus(
txn,
pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus,
pdu_entry.context
)
def get_unresolved_state_tree(self, new_state_pdu):
return self.runInteraction(
self._get_unresolved_state_tree, new_state_pdu
)
@log_function
def _get_unresolved_state_tree(self, txn, new_pdu):
current = self._get_current_interaction(
txn,
new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
)
ReturnType = namedtuple(
"StateReturnType", ["new_branch", "current_branch"]
)
return_value = ReturnType([new_pdu], [])
if not current:
logger.debug("get_unresolved_state_tree No current state.")
return (return_value, None)
return_value.current_branch.append(current)
enum_branches = self._enumerate_state_branches(
txn, new_pdu, current
)
missing_branch = None
for branch, prev_state, state in enum_branches:
if state:
return_value[branch].append(state)
else:
# We don't have prev_state :(
missing_branch = branch
break
return (return_value, missing_branch)
def update_current_state(self, pdu_id, origin, context, pdu_type,
state_key):
return self.runInteraction(
self._update_current_state,
pdu_id, origin, context, pdu_type, state_key
)
def _update_current_state(self, txn, pdu_id, origin, context, pdu_type,
state_key):
query = (
"INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
) % {
"curr": CurrentStateTable.table_name,
"fields": CurrentStateTable.get_fields_string(),
"qs": ", ".join(["?"] * len(CurrentStateTable.fields))
}
query_args = CurrentStateTable.EntryType(
pdu_id=pdu_id,
origin=origin,
context=context,
pdu_type=pdu_type,
state_key=state_key
)
txn.execute(query, query_args)
def get_current_state_pdu(self, context, pdu_type, state_key):
"""For a given context, pdu_type, state_key 3-tuple, return what is
currently considered the current state.
Args:
txn
context (str)
pdu_type (str)
state_key (str)
Returns:
PduEntry
"""
return self.runInteraction(
self._get_current_state_pdu, context, pdu_type, state_key
)
def _get_current_state_pdu(self, txn, context, pdu_type, state_key):
return self._get_current_interaction(txn, context, pdu_type, state_key)
def _get_current_interaction(self, txn, context, pdu_type, state_key):
logger.debug(
"_get_current_interaction %s %s %s",
context, pdu_type, state_key
)
fields = _pdu_state_joiner.get_fields(
PdusTable="p", StatePdusTable="s")
current_query = (
"SELECT %(fields)s FROM %(state)s as s "
"INNER JOIN %(pdus)s as p "
"ON s.pdu_id = p.pdu_id AND s.origin = p.origin "
"INNER JOIN %(curr)s as c "
"ON s.pdu_id = c.pdu_id AND s.origin = c.origin "
"WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? "
) % {
"fields": fields,
"curr": CurrentStateTable.table_name,
"state": StatePdusTable.table_name,
"pdus": PdusTable.table_name,
}
txn.execute(
current_query,
(context, pdu_type, state_key)
)
row = txn.fetchone()
result = PduEntry(*row) if row else None
if not result:
logger.debug("_get_current_interaction not found")
else:
logger.debug(
"_get_current_interaction found %s %s",
result.pdu_id, result.origin
)
return result
def handle_new_state(self, new_pdu):
"""Actually perform conflict resolution on the new_pdu on the
assumption we have all the pdus required to perform it.
Args:
new_pdu
Returns:
bool: True if the new_pdu clobbered the current state, False if not
"""
return self.runInteraction(
self._handle_new_state, new_pdu
)
def _handle_new_state(self, txn, new_pdu):
logger.debug(
"handle_new_state %s %s",
new_pdu.pdu_id, new_pdu.origin
)
current = self._get_current_interaction(
txn,
new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
)
is_current = False
if (not current or not current.prev_state_id
or not current.prev_state_origin):
# Oh, we don't have any state for this yet.
is_current = True
elif (current.pdu_id == new_pdu.prev_state_id
and current.origin == new_pdu.prev_state_origin):
# Oh! A direct clobber. Just do it.
is_current = True
else:
##
# Ok, now loop through until we get to a common ancestor.
max_new = int(new_pdu.power_level)
max_current = int(current.power_level)
enum_branches = self._enumerate_state_branches(
txn, new_pdu, current
)
for branch, prev_state, state in enum_branches:
if not state:
raise RuntimeError(
"Could not find state_pdu %s %s" %
(
prev_state.prev_state_id,
prev_state.prev_state_origin
)
)
if branch == 0:
max_new = max(int(state.depth), max_new)
else:
max_current = max(int(state.depth), max_current)
is_current = max_new > max_current
if is_current:
logger.debug("handle_new_state make current")
# Right, this is a new thing, so woo, just insert it.
txn.execute(
"INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
% {
"curr": CurrentStateTable.table_name,
"fields": CurrentStateTable.get_fields_string(),
"qs": ", ".join(["?"] * len(CurrentStateTable.fields))
},
CurrentStateTable.EntryType(
*(new_pdu.__dict__[k] for k in CurrentStateTable.fields)
)
)
else:
logger.debug("handle_new_state not current")
logger.debug("handle_new_state done")
return is_current
@log_function
def _enumerate_state_branches(self, txn, pdu_a, pdu_b):
branch_a = pdu_a
branch_b = pdu_b
while True:
if (branch_a.pdu_id == branch_b.pdu_id
and branch_a.origin == branch_b.origin):
# Woo! We found a common ancestor
logger.debug("_enumerate_state_branches Found common ancestor")
break
do_branch_a = (
hasattr(branch_a, "prev_state_id") and
branch_a.prev_state_id
)
do_branch_b = (
hasattr(branch_b, "prev_state_id") and
branch_b.prev_state_id
)
logger.debug(
"do_branch_a=%s, do_branch_b=%s",
do_branch_a, do_branch_b
)
if do_branch_a and do_branch_b:
do_branch_a = int(branch_a.depth) > int(branch_b.depth)
if do_branch_a:
pdu_tuple = PduIdTuple(
branch_a.prev_state_id,
branch_a.prev_state_origin
)
prev_branch = branch_a
logger.debug("getting branch_a prev %s", pdu_tuple)
branch_a = self._get_pdu_tuple(txn, *pdu_tuple)
if branch_a:
branch_a = Pdu.from_pdu_tuple(branch_a)
logger.debug("branch_a=%s", branch_a)
yield (0, prev_branch, branch_a)
if not branch_a:
break
elif do_branch_b:
pdu_tuple = PduIdTuple(
branch_b.prev_state_id,
branch_b.prev_state_origin
)
prev_branch = branch_b
logger.debug("getting branch_b prev %s", pdu_tuple)
branch_b = self._get_pdu_tuple(txn, *pdu_tuple)
if branch_b:
branch_b = Pdu.from_pdu_tuple(branch_b)
logger.debug("branch_b=%s", branch_b)
yield (1, prev_branch, branch_b)
if not branch_b:
break
else:
break
class PdusTable(Table):
table_name = "pdus"
fields = [
"pdu_id",
"origin",
"context",
"pdu_type",
"ts",
"depth",
"is_state",
"content_json",
"unrecognized_keys",
"outlier",
"have_processed",
]
EntryType = namedtuple("PdusEntry", fields)
class PduDestinationsTable(Table):
table_name = "pdu_destinations"
fields = [
"pdu_id",
"origin",
"destination",
"delivered_ts",
]
EntryType = namedtuple("PduDestinationsEntry", fields)
class PduEdgesTable(Table):
table_name = "pdu_edges"
fields = [
"pdu_id",
"origin",
"prev_pdu_id",
"prev_origin",
"context"
]
EntryType = namedtuple("PduEdgesEntry", fields)
class PduForwardExtremitiesTable(Table):
table_name = "pdu_forward_extremities"
fields = [
"pdu_id",
"origin",
"context",
]
EntryType = namedtuple("PduForwardExtremitiesEntry", fields)
class PduBackwardExtremitiesTable(Table):
table_name = "pdu_backward_extremities"
fields = [
"pdu_id",
"origin",
"context",
]
EntryType = namedtuple("PduBackwardExtremitiesEntry", fields)
class ContextDepthTable(Table):
table_name = "context_depth"
fields = [
"context",
"min_depth",
]
EntryType = namedtuple("ContextDepthEntry", fields)
class StatePdusTable(Table):
table_name = "state_pdus"
fields = [
"pdu_id",
"origin",
"context",
"pdu_type",
"state_key",
"power_level",
"prev_state_id",
"prev_state_origin",
]
EntryType = namedtuple("StatePdusEntry", fields)
class CurrentStateTable(Table):
table_name = "current_state"
fields = [
"pdu_id",
"origin",
"context",
"pdu_type",
"state_key",
]
EntryType = namedtuple("CurrentStateEntry", fields)
_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable)
# TODO: These should probably be put somewhere more sensible
PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin"))
PduEntry = _pdu_state_joiner.EntryType
""" We are always interested in the join of the PdusTable and StatePdusTable,
rather than just the PdusTable.
This does not include a prev_pdus key.
"""
PduTuple = namedtuple(
"PduTuple",
("pdu_entry", "prev_pdu_list")
)
""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
the `prev_pdus` key of a PDU.
"""

View File

@ -62,8 +62,10 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if the user_id could not be registered.
"""
yield self.runInteraction(self._register, user_id, token,
password_hash)
yield self.runInteraction(
"register",
self._register, user_id, token, password_hash
)
def _register(self, txn, user_id, token, password_hash):
now = int(self.clock.time())
@ -100,17 +102,22 @@ class RegistrationStore(SQLBaseStore):
StoreError if no user was found.
"""
return self.runInteraction(
"get_user_by_token",
self._query_for_auth,
token
)
@defer.inlineCallbacks
def is_server_admin(self, user):
return self._simple_select_one_onecol(
res = yield self._simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
allow_none=True,
)
defer.returnValue(res if res else False)
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.admin, access_tokens.device_id "

View File

@ -132,209 +132,29 @@ class RoomStore(SQLBaseStore):
defer.returnValue(ret)
@defer.inlineCallbacks
def get_room_join_rule(self, room_id):
sql = (
"SELECT join_rule FROM room_join_rules as r "
"INNER JOIN current_state_events as c "
"ON r.event_id = c.event_id "
"WHERE c.room_id = ? "
)
rows = yield self._execute(None, sql, room_id)
if len(rows) == 1:
defer.returnValue(rows[0][0])
else:
defer.returnValue(None)
def get_power_level(self, room_id, user_id):
return self.runInteraction(
self._get_power_level,
room_id, user_id,
)
def _get_power_level(self, txn, room_id, user_id):
sql = (
"SELECT level FROM room_power_levels as r "
"INNER JOIN current_state_events as c "
"ON r.event_id = c.event_id "
"WHERE c.room_id = ? AND r.user_id = ? "
)
rows = txn.execute(sql, (room_id, user_id,)).fetchall()
if len(rows) == 1:
return rows[0][0]
sql = (
"SELECT level FROM room_default_levels as r "
"INNER JOIN current_state_events as c "
"ON r.event_id = c.event_id "
"WHERE c.room_id = ? "
)
rows = txn.execute(sql, (room_id,)).fetchall()
if len(rows) == 1:
return rows[0][0]
else:
return None
def get_ops_levels(self, room_id):
return self.runInteraction(
self._get_ops_levels,
room_id,
)
def _get_ops_levels(self, txn, room_id):
sql = (
"SELECT ban_level, kick_level, redact_level "
"FROM room_ops_levels as r "
"INNER JOIN current_state_events as c "
"ON r.event_id = c.event_id "
"WHERE c.room_id = ? "
)
rows = txn.execute(sql, (room_id,)).fetchall()
if len(rows) == 1:
return OpsLevel(rows[0][0], rows[0][1], rows[0][2])
else:
return OpsLevel(None, None)
def get_add_state_level(self, room_id):
return self._get_level_from_table("room_add_state_levels", room_id)
def get_send_event_level(self, room_id):
return self._get_level_from_table("room_send_event_levels", room_id)
@defer.inlineCallbacks
def _get_level_from_table(self, table, room_id):
sql = (
"SELECT level FROM %(table)s as r "
"INNER JOIN current_state_events as c "
"ON r.event_id = c.event_id "
"WHERE c.room_id = ? "
) % {"table": table}
rows = yield self._execute(None, sql, room_id)
if len(rows) == 1:
defer.returnValue(rows[0][0])
else:
defer.returnValue(None)
def _store_room_topic_txn(self, txn, event):
self._simple_insert_txn(
txn,
"topics",
{
"event_id": event.event_id,
"room_id": event.room_id,
"topic": event.topic,
}
)
if hasattr(event, "topic"):
self._simple_insert_txn(
txn,
"topics",
{
"event_id": event.event_id,
"room_id": event.room_id,
"topic": event.topic,
}
)
def _store_room_name_txn(self, txn, event):
self._simple_insert_txn(
txn,
"room_names",
{
"event_id": event.event_id,
"room_id": event.room_id,
"name": event.name,
}
)
def _store_join_rule(self, txn, event):
self._simple_insert_txn(
txn,
"room_join_rules",
{
"event_id": event.event_id,
"room_id": event.room_id,
"join_rule": event.content["join_rule"],
},
)
def _store_power_levels(self, txn, event):
for user_id, level in event.content.items():
if user_id == "default":
self._simple_insert_txn(
txn,
"room_default_levels",
{
"event_id": event.event_id,
"room_id": event.room_id,
"level": level,
},
)
else:
self._simple_insert_txn(
txn,
"room_power_levels",
{
"event_id": event.event_id,
"room_id": event.room_id,
"user_id": user_id,
"level": level
},
)
def _store_default_level(self, txn, event):
self._simple_insert_txn(
txn,
"room_default_levels",
{
"event_id": event.event_id,
"room_id": event.room_id,
"level": event.content["default_level"],
},
)
def _store_add_state_level(self, txn, event):
self._simple_insert_txn(
txn,
"room_add_state_levels",
{
"event_id": event.event_id,
"room_id": event.room_id,
"level": event.content["level"],
},
)
def _store_send_event_level(self, txn, event):
self._simple_insert_txn(
txn,
"room_send_event_levels",
{
"event_id": event.event_id,
"room_id": event.room_id,
"level": event.content["level"],
},
)
def _store_ops_level(self, txn, event):
content = {
"event_id": event.event_id,
"room_id": event.room_id,
}
if "kick_level" in event.content:
content["kick_level"] = event.content["kick_level"]
if "ban_level" in event.content:
content["ban_level"] = event.content["ban_level"]
if "redact_level" in event.content:
content["redact_level"] = event.content["redact_level"]
self._simple_insert_txn(
txn,
"room_ops_levels",
content,
)
if hasattr(event, "name"):
self._simple_insert_txn(
txn,
"room_names",
{
"event_id": event.event_id,
"room_id": event.room_id,
"name": event.name,
}
)
class RoomsTable(Table):

View File

@ -1,31 +0,0 @@
/* Copyright 2014 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS context_edge_pdus(
id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
pdu_id TEXT,
origin TEXT,
context TEXT,
CONSTRAINT context_edge_pdu_id_origin UNIQUE (pdu_id, origin)
);
CREATE TABLE IF NOT EXISTS origin_edge_pdus(
id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
pdu_id TEXT,
origin TEXT,
CONSTRAINT origin_edge_pdu_id_origin UNIQUE (pdu_id, origin)
);
CREATE INDEX IF NOT EXISTS context_edge_pdu_id ON context_edge_pdus(pdu_id, origin);
CREATE INDEX IF NOT EXISTS origin_edge_pdu_id ON origin_edge_pdus(pdu_id, origin);

View File

@ -0,0 +1,75 @@
CREATE TABLE IF NOT EXISTS event_forward_extremities(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
);
CREATE INDEX IF NOT EXISTS ev_extrem_room ON event_forward_extremities(room_id);
CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_backward_extremities(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
);
CREATE INDEX IF NOT EXISTS ev_b_extrem_room ON event_backward_extremities(room_id);
CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_edges(
event_id TEXT NOT NULL,
prev_event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
is_state INTEGER NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id, is_state)
);
CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id);
CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id);
CREATE TABLE IF NOT EXISTS room_depth(
room_id TEXT NOT NULL,
min_depth INTEGER NOT NULL,
CONSTRAINT uniqueness UNIQUE (room_id)
);
CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id);
create TABLE IF NOT EXISTS event_destinations(
event_id TEXT NOT NULL,
destination TEXT NOT NULL,
delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE
);
CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id);
CREATE TABLE IF NOT EXISTS state_forward_extremities(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
);
CREATE INDEX IF NOT EXISTS st_extrem_keys ON state_forward_extremities(
room_id, type, state_key
);
CREATE INDEX IF NOT EXISTS st_extrem_id ON state_forward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_auth(
event_id TEXT NOT NULL,
auth_id TEXT NOT NULL,
room_id TEXT NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, auth_id, room_id)
);
CREATE INDEX IF NOT EXISTS evauth_edges_id ON event_auth(event_id);
CREATE INDEX IF NOT EXISTS evauth_edges_auth_id ON event_auth(auth_id);

View File

@ -0,0 +1,65 @@
/* Copyright 2014 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS event_content_hashes (
event_id TEXT,
algorithm TEXT,
hash BLOB,
CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
);
CREATE INDEX IF NOT EXISTS event_content_hashes_id ON event_content_hashes(
event_id
);
CREATE TABLE IF NOT EXISTS event_reference_hashes (
event_id TEXT,
algorithm TEXT,
hash BLOB,
CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
);
CREATE INDEX IF NOT EXISTS event_reference_hashes_id ON event_reference_hashes (
event_id
);
CREATE TABLE IF NOT EXISTS event_origin_signatures (
event_id TEXT,
origin TEXT,
key_id TEXT,
signature BLOB,
CONSTRAINT uniqueness UNIQUE (event_id, key_id)
);
CREATE INDEX IF NOT EXISTS event_origin_signatures_id ON event_origin_signatures (
event_id
);
CREATE TABLE IF NOT EXISTS event_edge_hashes(
event_id TEXT,
prev_event_id TEXT,
algorithm TEXT,
hash BLOB,
CONSTRAINT uniqueness UNIQUE (
event_id, prev_event_id, algorithm
)
);
CREATE INDEX IF NOT EXISTS event_edge_hashes_id ON event_edge_hashes(
event_id
);

View File

@ -23,6 +23,7 @@ CREATE TABLE IF NOT EXISTS events(
unrecognized_keys TEXT,
processed BOOL NOT NULL,
outlier BOOL NOT NULL,
depth INTEGER DEFAULT 0 NOT NULL,
CONSTRAINT ev_uniq UNIQUE (event_id)
);
@ -84,80 +85,24 @@ CREATE TABLE IF NOT EXISTS topics(
topic TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS topics_event_id ON topics(event_id);
CREATE INDEX IF NOT EXISTS topics_room_id ON topics(room_id);
CREATE TABLE IF NOT EXISTS room_names(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
name TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS room_names_event_id ON room_names(event_id);
CREATE INDEX IF NOT EXISTS room_names_room_id ON room_names(room_id);
CREATE TABLE IF NOT EXISTS rooms(
room_id TEXT PRIMARY KEY NOT NULL,
is_public INTEGER,
creator TEXT
);
CREATE TABLE IF NOT EXISTS room_join_rules(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
join_rule TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS room_join_rules_event_id ON room_join_rules(event_id);
CREATE INDEX IF NOT EXISTS room_join_rules_room_id ON room_join_rules(room_id);
CREATE TABLE IF NOT EXISTS room_power_levels(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
user_id TEXT NOT NULL,
level INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS room_power_levels_event_id ON room_power_levels(event_id);
CREATE INDEX IF NOT EXISTS room_power_levels_room_id ON room_power_levels(room_id);
CREATE INDEX IF NOT EXISTS room_power_levels_room_user ON room_power_levels(room_id, user_id);
CREATE TABLE IF NOT EXISTS room_default_levels(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
level INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS room_default_levels_event_id ON room_default_levels(event_id);
CREATE INDEX IF NOT EXISTS room_default_levels_room_id ON room_default_levels(room_id);
CREATE TABLE IF NOT EXISTS room_add_state_levels(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
level INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS room_add_state_levels_event_id ON room_add_state_levels(event_id);
CREATE INDEX IF NOT EXISTS room_add_state_levels_room_id ON room_add_state_levels(room_id);
CREATE TABLE IF NOT EXISTS room_send_event_levels(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
level INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS room_send_event_levels_event_id ON room_send_event_levels(event_id);
CREATE INDEX IF NOT EXISTS room_send_event_levels_room_id ON room_send_event_levels(room_id);
CREATE TABLE IF NOT EXISTS room_ops_levels(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
ban_level INTEGER,
kick_level INTEGER,
redact_level INTEGER
);
CREATE INDEX IF NOT EXISTS room_ops_levels_event_id ON room_ops_levels(event_id);
CREATE INDEX IF NOT EXISTS room_ops_levels_room_id ON room_ops_levels(room_id);
CREATE TABLE IF NOT EXISTS room_hosts(
room_id TEXT NOT NULL,
host TEXT NOT NULL,

View File

@ -1,106 +0,0 @@
/* Copyright 2014 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Stores pdus and their content
CREATE TABLE IF NOT EXISTS pdus(
pdu_id TEXT,
origin TEXT,
context TEXT,
pdu_type TEXT,
ts INTEGER,
depth INTEGER DEFAULT 0 NOT NULL,
is_state BOOL,
content_json TEXT,
unrecognized_keys TEXT,
outlier BOOL NOT NULL,
have_processed BOOL,
CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
);
-- Stores what the current state pdu is for a given (context, pdu_type, key) tuple
CREATE TABLE IF NOT EXISTS state_pdus(
pdu_id TEXT,
origin TEXT,
context TEXT,
pdu_type TEXT,
state_key TEXT,
power_level TEXT,
prev_state_id TEXT,
prev_state_origin TEXT,
CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
CONSTRAINT prev_pdu_id_origin UNIQUE (prev_state_id, prev_state_origin)
);
CREATE TABLE IF NOT EXISTS current_state(
pdu_id TEXT,
origin TEXT,
context TEXT,
pdu_type TEXT,
state_key TEXT,
CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
CONSTRAINT uniqueness UNIQUE (context, pdu_type, state_key) ON CONFLICT REPLACE
);
-- Stores where each pdu we want to send should be sent and the delivery status.
create TABLE IF NOT EXISTS pdu_destinations(
pdu_id TEXT,
origin TEXT,
destination TEXT,
delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, destination) ON CONFLICT REPLACE
);
CREATE TABLE IF NOT EXISTS pdu_forward_extremities(
pdu_id TEXT,
origin TEXT,
context TEXT,
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
);
CREATE TABLE IF NOT EXISTS pdu_backward_extremities(
pdu_id TEXT,
origin TEXT,
context TEXT,
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
);
CREATE TABLE IF NOT EXISTS pdu_edges(
pdu_id TEXT,
origin TEXT,
prev_pdu_id TEXT,
prev_origin TEXT,
context TEXT,
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, prev_pdu_id, prev_origin, context)
);
CREATE TABLE IF NOT EXISTS context_depth(
context TEXT,
min_depth INTEGER,
CONSTRAINT uniqueness UNIQUE (context)
);
CREATE INDEX IF NOT EXISTS context_depth_context ON context_depth(context);
CREATE INDEX IF NOT EXISTS pdu_id ON pdus(pdu_id, origin);
CREATE INDEX IF NOT EXISTS dests_id ON pdu_destinations (pdu_id, origin);
-- CREATE INDEX IF NOT EXISTS dests ON pdu_destinations (destination);
CREATE INDEX IF NOT EXISTS pdu_extrem_context ON pdu_forward_extremities(context);
CREATE INDEX IF NOT EXISTS pdu_extrem_id ON pdu_forward_extremities(pdu_id, origin);
CREATE INDEX IF NOT EXISTS pdu_edges_id ON pdu_edges(pdu_id, origin);
CREATE INDEX IF NOT EXISTS pdu_b_extrem_context ON pdu_backward_extremities(context);

View File

@ -0,0 +1,33 @@
/* Copyright 2014 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS state_groups(
id INTEGER PRIMARY KEY,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS state_groups_state(
state_group INTEGER NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
event_id TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS event_to_state_groups(
event_id TEXT NOT NULL,
state_group INTEGER NOT NULL
);

View File

@ -0,0 +1,177 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from _base import SQLBaseStore
class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes"""
def _get_event_content_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given Event.
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
A dict of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
" FROM event_content_hashes"
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
return dict(txn.fetchall())
def _store_event_content_hash_txn(self, txn, event_id, algorithm,
hash_bytes):
"""Store a hash for a Event
Args:
txn (cursor):
event_id (str): Id for the Event.
algorithm (str): Hashing algorithm.
hash_bytes (bytes): Hash function output bytes.
"""
self._simple_insert_txn(
txn,
"event_content_hashes",
{
"event_id": event_id,
"algorithm": algorithm,
"hash": buffer(hash_bytes),
},
or_ignore=True,
)
def get_event_reference_hashes(self, event_ids):
def f(txn):
return [
self._get_event_reference_hashes_txn(txn, ev)
for ev in event_ids
]
return self.runInteraction(
"get_event_reference_hashes",
f
)
def _get_event_reference_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given PDU.
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
A dict of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
" FROM event_reference_hashes"
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
return dict(txn.fetchall())
def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
hash_bytes):
"""Store a hash for a PDU
Args:
txn (cursor):
event_id (str): Id for the Event.
algorithm (str): Hashing algorithm.
hash_bytes (bytes): Hash function output bytes.
"""
self._simple_insert_txn(
txn,
"event_reference_hashes",
{
"event_id": event_id,
"algorithm": algorithm,
"hash": buffer(hash_bytes),
},
or_ignore=True,
)
def _get_event_origin_signatures_txn(self, txn, event_id):
"""Get all the signatures for a given PDU.
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
A dict of key_id -> signature_bytes.
"""
query = (
"SELECT key_id, signature"
" FROM event_origin_signatures"
" WHERE event_id = ? "
)
txn.execute(query, (event_id, ))
return dict(txn.fetchall())
def _store_event_origin_signature_txn(self, txn, event_id, origin, key_id,
signature_bytes):
"""Store a signature from the origin server for a PDU.
Args:
txn (cursor):
event_id (str): Id for the Event.
origin (str): origin of the Event.
key_id (str): Id for the signing key.
signature (bytes): The signature.
"""
self._simple_insert_txn(
txn,
"event_origin_signatures",
{
"event_id": event_id,
"origin": origin,
"key_id": key_id,
"signature": buffer(signature_bytes),
},
or_ignore=True,
)
def _get_prev_event_hashes_txn(self, txn, event_id):
"""Get all the hashes for previous PDUs of a PDU
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes.
"""
query = (
"SELECT prev_event_id, algorithm, hash"
" FROM event_edge_hashes"
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
results = {}
for prev_event_id, algorithm, hash_bytes in txn.fetchall():
hashes = results.setdefault(prev_event_id, {})
hashes[algorithm] = hash_bytes
return results
def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id,
algorithm, hash_bytes):
self._simple_insert_txn(
txn,
"event_edge_hashes",
{
"event_id": event_id,
"prev_event_id": prev_event_id,
"algorithm": algorithm,
"hash": buffer(hash_bytes),
},
or_ignore=True,
)

96
synapse/storage/state.py Normal file
View File

@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from twisted.internet import defer
class StateStore(SQLBaseStore):
@defer.inlineCallbacks
def get_state_groups(self, event_ids):
groups = set()
for event_id in event_ids:
group = yield self._simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
allow_none=True,
)
if group:
groups.add(group)
res = {}
for group in groups:
state_ids = yield self._simple_select_onecol(
table="state_groups_state",
keyvalues={"state_group": group},
retcol="event_id",
)
state = []
for state_id in state_ids:
s = yield self.get_event(
state_id,
allow_none=True,
)
if s:
state.append(s)
res[group] = state
defer.returnValue(res)
def store_state_groups(self, event):
return self.runInteraction(
"store_state_groups",
self._store_state_groups_txn, event
)
def _store_state_groups_txn(self, txn, event):
if not event.state_events:
return
state_group = event.state_group
if not state_group:
state_group = self._simple_insert_txn(
txn,
table="state_groups",
values={
"room_id": event.room_id,
"event_id": event.event_id,
}
)
for state in event.state_events.values():
self._simple_insert_txn(
txn,
table="state_groups_state",
values={
"state_group": state_group,
"room_id": state.room_id,
"type": state.type,
"state_key": state.state_key,
"event_id": state.event_id,
}
)
self._simple_insert_txn(
txn,
table="event_to_state_groups",
values={
"state_group": state_group,
"event_id": event.event_id,
}
)

View File

@ -177,10 +177,9 @@ class StreamStore(SQLBaseStore):
sql = (
"SELECT *, (%(redacted)s) AS redacted FROM events AS e WHERE "
"((room_id IN (%(current)s)) OR "
"(e.outlier = 0 AND (room_id IN (%(current)s)) OR "
"(event_id IN (%(invites)s))) "
"AND e.stream_ordering > ? AND e.stream_ordering <= ? "
"AND e.outlier = 0 "
"ORDER BY stream_ordering ASC LIMIT %(limit)d "
) % {
"redacted": del_sql,
@ -309,7 +308,10 @@ class StreamStore(SQLBaseStore):
defer.returnValue(ret)
def get_room_events_max_id(self):
return self.runInteraction(self._get_room_events_max_id_txn)
return self.runInteraction(
"get_room_events_max_id",
self._get_room_events_max_id_txn
)
def _get_room_events_max_id_txn(self, txn):
txn.execute(

View File

@ -14,7 +14,6 @@
# limitations under the License.
from ._base import SQLBaseStore, Table
from .pdu import PdusTable
from collections import namedtuple
@ -42,6 +41,7 @@ class TransactionStore(SQLBaseStore):
"""
return self.runInteraction(
"get_received_txn_response",
self._get_received_txn_response, transaction_id, origin
)
@ -73,6 +73,7 @@ class TransactionStore(SQLBaseStore):
"""
return self.runInteraction(
"set_received_txn_response",
self._set_received_txn_response,
transaction_id, origin, code, response_dict
)
@ -88,7 +89,7 @@ class TransactionStore(SQLBaseStore):
txn.execute(query, (code, response_json, transaction_id, origin))
def prep_send_transaction(self, transaction_id, destination,
origin_server_ts, pdu_list):
origin_server_ts):
"""Persists an outgoing transaction and calculates the values for the
previous transaction id list.
@ -99,19 +100,19 @@ class TransactionStore(SQLBaseStore):
transaction_id (str)
destination (str)
origin_server_ts (int)
pdu_list (list)
Returns:
list: A list of previous transaction ids.
"""
return self.runInteraction(
"prep_send_transaction",
self._prep_send_transaction,
transaction_id, destination, origin_server_ts, pdu_list
transaction_id, destination, origin_server_ts
)
def _prep_send_transaction(self, txn, transaction_id, destination,
origin_server_ts, pdu_list):
origin_server_ts):
# First we find out what the prev_txs should be.
# Since we know that we are only sending one transaction at a time,
@ -139,15 +140,15 @@ class TransactionStore(SQLBaseStore):
# Update the tx id -> pdu id mapping
values = [
(transaction_id, destination, pdu[0], pdu[1])
for pdu in pdu_list
]
logger.debug("Inserting: %s", repr(values))
query = TransactionsToPduTable.insert_statement()
txn.executemany(query, values)
# values = [
# (transaction_id, destination, pdu[0], pdu[1])
# for pdu in pdu_list
# ]
#
# logger.debug("Inserting: %s", repr(values))
#
# query = TransactionsToPduTable.insert_statement()
# txn.executemany(query, values)
return prev_txns
@ -161,6 +162,7 @@ class TransactionStore(SQLBaseStore):
response_json (str)
"""
return self.runInteraction(
"delivered_txn",
self._delivered_txn,
transaction_id, destination, code, response_dict
)
@ -186,6 +188,7 @@ class TransactionStore(SQLBaseStore):
list: A list of `ReceivedTransactionsTable.EntryType`
"""
return self.runInteraction(
"get_transactions_after",
self._get_transactions_after, transaction_id, destination
)
@ -202,49 +205,6 @@ class TransactionStore(SQLBaseStore):
return ReceivedTransactionsTable.decode_results(txn.fetchall())
def get_pdus_after_transaction(self, transaction_id, destination):
"""For a given local transaction_id that we sent to a given destination
home server, return a list of PDUs that were sent to that destination
after it.
Args:
txn
transaction_id (str)
destination (str)
Returns
list: A list of PduTuple
"""
return self.runInteraction(
self._get_pdus_after_transaction,
transaction_id, destination
)
def _get_pdus_after_transaction(self, txn, transaction_id, destination):
# Query that first get's all transaction_ids with an id greater than
# the one given from the `sent_transactions` table. Then JOIN on this
# from the `tx->pdu` table to get a list of (pdu_id, origin) that
# specify the pdus that were sent in those transactions.
query = (
"SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp "
"INNER JOIN %(sent_tx)s as st "
"ON tp.transaction_id = st.transaction_id "
"AND tp.destination = st.destination "
"WHERE st.id > ("
"SELECT id FROM %(sent_tx)s "
"WHERE transaction_id = ? AND destination = ?"
) % {
"tx_pdu": TransactionsToPduTable.table_name,
"sent_tx": SentTransactions.table_name,
}
txn.execute(query, (transaction_id, destination))
pdus = PdusTable.decode_results(txn.fetchall())
return self._get_pdu_tuples(txn, pdus)
class ReceivedTransactionsTable(Table):
table_name = "received_transactions"

View File

@ -78,6 +78,11 @@ class DomainSpecificString(
"""Create a structure on the local domain"""
return cls(localpart=localpart, domain=hs.hostname, is_mine=True)
@classmethod
def create(cls, localpart, domain, hs):
is_mine = domain == hs.hostname
return cls(localpart=localpart, domain=domain, is_mine=is_mine)
class UserID(DomainSpecificString):
"""Structure representing a user ID."""
@ -94,6 +99,11 @@ class RoomID(DomainSpecificString):
SIGIL = "!"
class EventID(DomainSpecificString):
"""Structure representing an event id. """
SIGIL = "$"
class StreamToken(
namedtuple(
"Token",

View File

@ -21,3 +21,10 @@ def sleep(seconds):
d = defer.Deferred()
reactor.callLater(seconds, d.callback, seconds)
return d
def run_on_reactor():
""" This will cause the rest of the function to be invoked upon the next
iteration of the main loop
"""
return sleep(0)

View File

@ -80,7 +80,7 @@ class JsonEncodedObject(object):
def get_full_dict(self):
d = {
k: v for (k, v) in self.__dict__.items()
k: _encode(v) for (k, v) in self.__dict__.items()
if k in self.valid_keys or k in self.internal_keys
}
d.update(self.unrecognized_keys)

View File

@ -14,6 +14,8 @@
# limitations under the License.
from synapse.api.events import SynapseEvent
from synapse.api.events.validator import EventValidator
from synapse.api.errors import SynapseError
from tests import unittest
@ -21,7 +23,7 @@ from tests import unittest
class SynapseTemplateCheckTestCase(unittest.TestCase):
def setUp(self):
pass
self.validator = EventValidator(None)
def tearDown(self):
pass
@ -38,22 +40,28 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
}
event = MockSynapseEvent(template)
self.assertTrue(event.check_json(content, raises=False))
event.content = content
self.assertTrue(self.validator.validate(event))
content = {
"person": {"name": "bob"},
"friends": ["jill"],
"enemies": ["mike"]
}
event = MockSynapseEvent(template)
self.assertTrue(event.check_json(content, raises=False))
event.content = content
self.assertTrue(self.validator.validate(event))
content = {
"person": {"name": "bob"},
# missing friends
"enemies": ["mike", "jill"]
}
self.assertFalse(event.check_json(content, raises=False))
event.content = content
self.assertRaises(
SynapseError,
self.validator.validate,
event
)
def test_lists(self):
template = {
@ -67,13 +75,19 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
}
event = MockSynapseEvent(template)
self.assertFalse(event.check_json(content, raises=False))
event.content = content
self.assertRaises(
SynapseError,
self.validator.validate,
event
)
content = {
"person": {"name": "bob"},
"friends": [{"name": "jill"}, {"name": "mike"}]
}
self.assertTrue(event.check_json(content, raises=False))
event.content = content
self.assertTrue(self.validator.validate(event))
def test_nested_lists(self):
template = {
@ -103,7 +117,12 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
}
event = MockSynapseEvent(template)
self.assertFalse(event.check_json(content, raises=False))
event.content = content
self.assertRaises(
SynapseError,
self.validator.validate,
event
)
content = {
"results": {
@ -117,7 +136,8 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
]
}
}
self.assertTrue(event.check_json(content, raises=False))
event.content = content
self.assertTrue(self.validator.validate(event))
def test_nested_keys(self):
template = {
@ -145,7 +165,8 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
}
}
self.assertTrue(event.check_json(content, raises=False))
event.content = content
self.assertTrue(self.validator.validate(event))
content = {
"person": {
@ -159,7 +180,12 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
}
}
self.assertFalse(event.check_json(content, raises=False))
event.content = content
self.assertRaises(
SynapseError,
self.validator.validate,
event
)
content = {
"person": {
@ -173,7 +199,12 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
}
}
self.assertFalse(event.check_json(content, raises=False))
event.content = content
self.assertRaises(
SynapseError,
self.validator.validate,
event
)
class MockSynapseEvent(SynapseEvent):

View File

@ -24,7 +24,6 @@ from ..utils import MockHttpResource, MockClock, MockKey
from synapse.server import HomeServer
from synapse.federation import initialize_http_replication
from synapse.federation.units import Pdu
from synapse.storage.pdu import PduTuple, PduEntry
def make_pdu(prev_pdus=[], **kwargs):
@ -41,7 +40,7 @@ def make_pdu(prev_pdus=[], **kwargs):
}
pdu_fields.update(kwargs)
return PduTuple(PduEntry(**pdu_fields), prev_pdus)
return Pdu(prev_pdus=prev_pdus, **pdu_fields)
class FederationTestCase(unittest.TestCase):
@ -52,177 +51,185 @@ class FederationTestCase(unittest.TestCase):
"put_json",
])
self.mock_persistence = Mock(spec=[
"get_current_state_for_context",
"get_pdu",
"persist_event",
"update_min_depth_for_context",
"prep_send_transaction",
"delivered_txn",
"get_received_txn_response",
"set_received_txn_response",
])
self.mock_persistence.get_received_txn_response.return_value = (
defer.succeed(None)
defer.succeed(None)
)
self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()]
self.clock = MockClock()
hs = HomeServer("test",
resource_for_federation=self.mock_resource,
http_client=self.mock_http_client,
db_pool=None,
datastore=self.mock_persistence,
clock=self.clock,
config=self.mock_config,
keyring=Mock(),
hs = HomeServer(
"test",
resource_for_federation=self.mock_resource,
http_client=self.mock_http_client,
db_pool=None,
datastore=self.mock_persistence,
clock=self.clock,
config=self.mock_config,
keyring=Mock(),
)
self.federation = initialize_http_replication(hs)
self.distributor = hs.get_distributor()
@defer.inlineCallbacks
def test_get_state(self):
self.mock_persistence.get_current_state_for_context.return_value = (
defer.succeed([])
)
mock_handler = Mock(spec=[
"get_state_for_pdu",
])
self.federation.set_handler(mock_handler)
mock_handler.get_state_for_pdu.return_value = defer.succeed([])
# Empty context initially
(code, response) = yield self.mock_resource.trigger("GET",
"/_matrix/federation/v1/state/my-context/", None)
(code, response) = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/state/my-context/",
None
)
self.assertEquals(200, code)
self.assertFalse(response["pdus"])
# Now lets give the context some state
self.mock_persistence.get_current_state_for_context.return_value = (
mock_handler.get_state_for_pdu.return_value = (
defer.succeed([
make_pdu(
pdu_id="the-pdu-id",
event_id="the-pdu-id",
origin="red",
context="my-context",
pdu_type="m.topic",
ts=123456789000,
room_id="my-context",
type="m.topic",
origin_server_ts=123456789000,
depth=1,
is_state=True,
content_json='{"topic":"The topic"}',
content={"topic": "The topic"},
state_key="",
power_level=1000,
prev_state_id="last-pdu-id",
prev_state_origin="blue",
prev_state="last-pdu-id",
),
])
)
(code, response) = yield self.mock_resource.trigger("GET",
"/_matrix/federation/v1/state/my-context/", None)
(code, response) = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/state/my-context/",
None
)
self.assertEquals(200, code)
self.assertEquals(1, len(response["pdus"]))
@defer.inlineCallbacks
def test_get_pdu(self):
self.mock_persistence.get_pdu.return_value = (
mock_handler = Mock(spec=[
"get_persisted_pdu",
])
self.federation.set_handler(mock_handler)
mock_handler.get_persisted_pdu.return_value = (
defer.succeed(None)
)
(code, response) = yield self.mock_resource.trigger("GET",
"/_matrix/federation/v1/pdu/red/abc123def456/", None)
(code, response) = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/event/abc123def456/",
None
)
self.assertEquals(404, code)
# Now insert such a PDU
self.mock_persistence.get_pdu.return_value = (
mock_handler.get_persisted_pdu.return_value = (
defer.succeed(
make_pdu(
pdu_id="abc123def456",
event_id="abc123def456",
origin="red",
context="my-context",
pdu_type="m.text",
ts=123456789001,
room_id="my-context",
type="m.text",
origin_server_ts=123456789001,
depth=1,
content_json='{"text":"Here is the message"}',
content={"text": "Here is the message"},
)
)
)
(code, response) = yield self.mock_resource.trigger("GET",
"/_matrix/federation/v1/pdu/red/abc123def456/", None)
(code, response) = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/event/abc123def456/",
None
)
self.assertEquals(200, code)
self.assertEquals(1, len(response["pdus"]))
self.assertEquals("m.text", response["pdus"][0]["pdu_type"])
self.assertEquals("m.text", response["pdus"][0]["type"])
@defer.inlineCallbacks
def test_send_pdu(self):
self.mock_http_client.put_json.return_value = defer.succeed(
(200, "OK")
(200, "OK")
)
pdu = Pdu(
pdu_id="abc123def456",
origin="red",
destinations=["remote"],
context="my-context",
origin_server_ts=123456789002,
pdu_type="m.test",
content={"testing": "content here"},
depth=1,
event_id="abc123def456",
origin="red",
room_id="my-context",
type="m.text",
origin_server_ts=123456789001,
depth=1,
content={"text": "Here is the message"},
destinations=["remote"],
)
yield self.federation.send_pdu(pdu)
self.mock_http_client.put_json.assert_called_with(
"remote",
path="/_matrix/federation/v1/send/1000000/",
data={
"origin_server_ts": 1000000,
"origin": "test",
"pdus": [
{
"origin": "red",
"pdu_id": "abc123def456",
"prev_pdus": [],
"origin_server_ts": 123456789002,
"context": "my-context",
"pdu_type": "m.test",
"is_state": False,
"content": {"testing": "content here"},
"depth": 1,
},
]
},
json_data_callback=ANY,
"remote",
path="/_matrix/federation/v1/send/1000000/",
data={
"origin_server_ts": 1000000,
"origin": "test",
"pdus": [
pdu.get_dict(),
],
'pdu_failures': [],
},
json_data_callback=ANY,
)
@defer.inlineCallbacks
def test_send_edu(self):
self.mock_http_client.put_json.return_value = defer.succeed(
(200, "OK")
(200, "OK")
)
yield self.federation.send_edu(
destination="remote",
edu_type="m.test",
content={"testing": "content here"},
destination="remote",
edu_type="m.test",
content={"testing": "content here"},
)
# MockClock ensures we can guess these timestamps
self.mock_http_client.put_json.assert_called_with(
"remote",
path="/_matrix/federation/v1/send/1000000/",
data={
"origin": "test",
"origin_server_ts": 1000000,
"pdus": [],
"edus": [
{
# TODO: SYN-103: Remove "origin" and "destination"
"origin": "test",
"destination": "remote",
"edu_type": "m.test",
"content": {"testing": "content here"},
}
],
},
json_data_callback=ANY,
"remote",
path="/_matrix/federation/v1/send/1000000/",
data={
"origin": "test",
"origin_server_ts": 1000000,
"pdus": [],
"edus": [
{
# TODO: SYN-103: Remove "origin" and "destination"
"origin": "test",
"destination": "remote",
"edu_type": "m.test",
"content": {"testing": "content here"},
}
],
'pdu_failures': [],
},
json_data_callback=ANY,
)
@defer.inlineCallbacks
def test_recv_edu(self):
recv_observer = Mock()
@ -230,24 +237,26 @@ class FederationTestCase(unittest.TestCase):
self.federation.register_edu_handler("m.test", recv_observer)
yield self.mock_resource.trigger("PUT",
"/_matrix/federation/v1/send/1001000/",
"""{
"origin": "remote",
"origin_server_ts": 1001000,
"pdus": [],
"edus": [
{
"origin": "remote",
"destination": "test",
"edu_type": "m.test",
"content": {"testing": "reply here"}
}
]
}""")
yield self.mock_resource.trigger(
"PUT",
"/_matrix/federation/v1/send/1001000/",
"""{
"origin": "remote",
"origin_server_ts": 1001000,
"pdus": [],
"edus": [
{
"origin": "remote",
"destination": "test",
"edu_type": "m.test",
"content": {"testing": "reply here"}
}
]
}"""
)
recv_observer.assert_called_with(
"remote", {"testing": "reply here"}
"remote", {"testing": "reply here"}
)
@defer.inlineCallbacks
@ -278,8 +287,11 @@ class FederationTestCase(unittest.TestCase):
self.federation.register_query_handler("a-question", recv_handler)
code, response = yield self.mock_resource.trigger("GET",
"/_matrix/federation/v1/query/a-question?three=3&four=4", None)
code, response = yield self.mock_resource.trigger(
"GET",
"/_matrix/federation/v1/query/a-question?three=3&four=4",
None
)
self.assertEquals(200, code)
self.assertEquals({"another": "response"}, response)

View File

@ -1,160 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from synapse.federation.pdu_codec import (
PduCodec, encode_event_id, decode_event_id
)
from synapse.federation.units import Pdu
#from synapse.api.events.room import MessageEvent
from synapse.server import HomeServer
from mock import Mock
class PduCodecTestCase(unittest.TestCase):
def setUp(self):
self.hs = HomeServer("blargle.net")
self.event_factory = self.hs.get_event_factory()
self.codec = PduCodec(self.hs)
def test_decode_event_id(self):
self.assertEquals(
("foo", "bar.com"),
decode_event_id("foo@bar.com", "A")
)
self.assertEquals(
("foo", "bar.com"),
decode_event_id("foo", "bar.com")
)
def test_encode_event_id(self):
self.assertEquals("A@B", encode_event_id("A", "B"))
def test_codec_event_id(self):
event_id = "aa@bb.com"
self.assertEquals(
event_id,
encode_event_id(*decode_event_id(event_id, None))
)
pdu_id = ("aa", "bb.com")
self.assertEquals(
pdu_id,
decode_event_id(encode_event_id(*pdu_id), None)
)
def test_event_from_pdu(self):
pdu = Pdu(
pdu_id="foo",
context="rooooom",
pdu_type="m.room.message",
origin="bar.com",
origin_server_ts=12345,
depth=5,
prev_pdus=[("alice", "bob.com")],
is_state=False,
content={"msgtype": u"test"},
)
event = self.codec.event_from_pdu(pdu)
self.assertEquals("foo@bar.com", event.event_id)
self.assertEquals(pdu.context, event.room_id)
self.assertEquals(pdu.is_state, event.is_state)
self.assertEquals(pdu.depth, event.depth)
self.assertEquals(["alice@bob.com"], event.prev_events)
self.assertEquals(pdu.content, event.content)
def test_pdu_from_event(self):
event = self.event_factory.create_event(
etype="m.room.message",
event_id="gargh_id",
room_id="rooom",
user_id="sender",
content={"msgtype": u"test"},
)
pdu = self.codec.pdu_from_event(event)
self.assertEquals(event.event_id, pdu.pdu_id)
self.assertEquals(self.hs.hostname, pdu.origin)
self.assertEquals(event.room_id, pdu.context)
self.assertEquals(event.content, pdu.content)
self.assertEquals(event.type, pdu.pdu_type)
event = self.event_factory.create_event(
etype="m.room.message",
event_id="gargh_id@bob.com",
room_id="rooom",
user_id="sender",
content={"msgtype": u"test"},
)
pdu = self.codec.pdu_from_event(event)
self.assertEquals("gargh_id", pdu.pdu_id)
self.assertEquals("bob.com", pdu.origin)
self.assertEquals(event.room_id, pdu.context)
self.assertEquals(event.content, pdu.content)
self.assertEquals(event.type, pdu.pdu_type)
def test_event_from_state_pdu(self):
pdu = Pdu(
pdu_id="foo",
context="rooooom",
pdu_type="m.room.topic",
origin="bar.com",
origin_server_ts=12345,
depth=5,
prev_pdus=[("alice", "bob.com")],
is_state=True,
content={"topic": u"test"},
state_key="",
)
event = self.codec.event_from_pdu(pdu)
self.assertEquals("foo@bar.com", event.event_id)
self.assertEquals(pdu.context, event.room_id)
self.assertEquals(pdu.is_state, event.is_state)
self.assertEquals(pdu.depth, event.depth)
self.assertEquals(["alice@bob.com"], event.prev_events)
self.assertEquals(pdu.content, event.content)
self.assertEquals(pdu.state_key, event.state_key)
def test_pdu_from_state_event(self):
event = self.event_factory.create_event(
etype="m.room.topic",
event_id="gargh_id",
room_id="rooom",
user_id="sender",
content={"topic": u"test"},
)
pdu = self.codec.pdu_from_event(event)
self.assertEquals(event.event_id, pdu.pdu_id)
self.assertEquals(self.hs.hostname, pdu.origin)
self.assertEquals(event.room_id, pdu.context)
self.assertEquals(event.content, pdu.content)
self.assertEquals(event.type, pdu.pdu_type)
self.assertEquals(event.state_key, pdu.state_key)

View File

@ -21,9 +21,8 @@ from mock import Mock
from synapse.server import HomeServer
from synapse.handlers.directory import DirectoryHandler
from synapse.storage.directory import RoomAliasMapping
from tests.utils import SQLiteMemoryDbPool
from tests.utils import SQLiteMemoryDbPool, MockKey
class DirectoryHandlers(object):
@ -41,6 +40,7 @@ class DirectoryTestCase(unittest.TestCase):
])
self.query_handlers = {}
def register_query_handler(query_type, handler):
self.query_handlers[query_type] = handler
self.mock_federation.register_query_handler = register_query_handler
@ -48,11 +48,16 @@ class DirectoryTestCase(unittest.TestCase):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer("test",
self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer(
"test",
db_pool=db_pool,
http_client=None,
resource_for_federation=Mock(),
replication_layer=self.mock_federation,
config=self.mock_config,
)
hs.handlers = DirectoryHandlers(hs)

View File

@ -17,16 +17,15 @@ from twisted.internet import defer
from tests import unittest
from synapse.api.events.room import (
InviteJoinEvent, MessageEvent, RoomMemberEvent
MessageEvent,
)
from synapse.api.constants import Membership
from synapse.handlers.federation import FederationHandler
from synapse.server import HomeServer
from synapse.federation.units import Pdu
from mock import NonCallableMock, ANY
from ..utils import get_mock_call_args, MockKey
from ..utils import MockKey
class FederationTestCase(unittest.TestCase):
@ -36,6 +35,14 @@ class FederationTestCase(unittest.TestCase):
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
self.state_handler = NonCallableMock(spec_set=[
"annotate_state_groups",
])
self.auth = NonCallableMock(spec_set=[
"check",
])
self.hostname = "test"
hs = HomeServer(
self.hostname,
@ -53,6 +60,8 @@ class FederationTestCase(unittest.TestCase):
"federation_handler",
]),
config=self.mock_config,
auth=self.auth,
state_handler=self.state_handler,
)
self.datastore = hs.get_datastore()
@ -65,74 +74,35 @@ class FederationTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_msg(self):
pdu = Pdu(
pdu_type=MessageEvent.TYPE,
context="foo",
type=MessageEvent.TYPE,
room_id="foo",
content={"msgtype": u"fooo"},
origin_server_ts=0,
pdu_id="a",
event_id="$a:b",
origin="b",
)
store_id = "ASD"
self.datastore.persist_event.return_value = defer.succeed(store_id)
self.datastore.persist_event.return_value = defer.succeed(None)
self.datastore.get_room.return_value = defer.succeed(True)
self.state_handler.annotate_state_groups.return_value = (
defer.succeed(False)
)
yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
self.datastore.persist_event.assert_called_once_with(
ANY, False, is_new_state=False
)
self.notifier.on_new_room_event.assert_called_once_with(ANY, extra_users=[])
@defer.inlineCallbacks
def test_invite_join_target_this(self):
room_id = "foo"
user_id = "@bob:red"
pdu = Pdu(
pdu_type=InviteJoinEvent.TYPE,
user_id=user_id,
target_host=self.hostname,
context=room_id,
content={},
origin_server_ts=0,
pdu_id="a",
origin="b",
self.state_handler.annotate_state_groups.assert_called_once_with(
ANY,
old_state=None,
)
yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
self.auth.check.assert_called_once_with(ANY, raises=True)
mem_handler = self.handlers.room_member_handler
self.assertEquals(1, mem_handler.change_membership.call_count)
call_args = get_mock_call_args(
lambda event, do_auth: None,
mem_handler.change_membership
self.notifier.on_new_room_event.assert_called_once_with(
ANY,
extra_users=[]
)
self.assertEquals(False, call_args["do_auth"])
new_event = call_args["event"]
self.assertEquals(RoomMemberEvent.TYPE, new_event.type)
self.assertEquals(room_id, new_event.room_id)
self.assertEquals(user_id, new_event.state_key)
self.assertEquals(Membership.JOIN, new_event.membership)
@defer.inlineCallbacks
def test_invite_join_target_other(self):
room_id = "foo"
user_id = "@bob:red"
pdu = Pdu(
pdu_type=InviteJoinEvent.TYPE,
user_id=user_id,
state_key="@red:not%s" % self.hostname,
context=room_id,
content={},
origin_server_ts=0,
pdu_id="a",
origin="b",
)
yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
mem_handler = self.handlers.room_member_handler
self.assertEquals(0, mem_handler.change_membership.call_count)

View File

@ -51,6 +51,7 @@ def _expect_edu(destination, edu_type, content, origin="test"):
"content": content,
}
],
"pdu_failures": [],
}
def _make_edu_json(origin, edu_type, content):

View File

@ -21,7 +21,7 @@ from twisted.internet import defer
from mock import Mock, call, ANY
from ..utils import MockClock
from ..utils import MockClock, MockKey
from synapse.server import HomeServer
from synapse.api.constants import PresenceState
@ -57,6 +57,9 @@ class PresenceAndProfileHandlers(object):
class PresenceProfilelikeDataTestCase(unittest.TestCase):
def setUp(self):
self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer("test",
clock=MockClock(),
db_pool=None,
@ -72,6 +75,7 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
resource_for_federation=Mock(),
http_client=None,
replication_layer=MockReplication(),
config=self.mock_config,
)
hs.handlers = PresenceAndProfileHandlers(hs)

View File

@ -24,7 +24,7 @@ from synapse.server import HomeServer
from synapse.handlers.profile import ProfileHandler
from synapse.api.constants import Membership
from tests.utils import SQLiteMemoryDbPool
from tests.utils import SQLiteMemoryDbPool, MockKey
class ProfileHandlers(object):
@ -49,12 +49,16 @@ class ProfileTestCase(unittest.TestCase):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
self.mock_config = Mock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer("test",
db_pool=db_pool,
http_client=None,
handlers=None,
resource_for_federation=Mock(),
replication_layer=self.mock_federation,
config=self.mock_config,
)
hs.handlers = ProfileHandlers(hs)

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from tests import unittest
from synapse.api.events.room import (
InviteJoinEvent, RoomMemberEvent, RoomConfigEvent
RoomMemberEvent,
)
from synapse.api.constants import Membership
from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler
@ -34,6 +34,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
def setUp(self):
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
self.hostname = "red"
hs = HomeServer(
self.hostname,
@ -57,13 +58,16 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
"profile_handler",
"federation_handler",
]),
auth=NonCallableMock(spec_set=["check"]),
state_handler=NonCallableMock(spec_set=["handle_new_event"]),
auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
state_handler=NonCallableMock(spec_set=[
"annotate_state_groups",
]),
config=self.mock_config,
)
self.federation = NonCallableMock(spec_set=[
"handle_new_event",
"send_invite",
"get_state_for_room",
])
@ -106,7 +110,6 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
joined = ["red", "green"]
self.state_handler.handle_new_event.return_value = defer.succeed(True)
self.datastore.get_joined_hosts_for_room.return_value = (
defer.succeed(joined)
)
@ -114,18 +117,29 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
store_id = "store_id_fooo"
self.datastore.persist_event.return_value = defer.succeed(store_id)
self.datastore.get_room_member.return_value = defer.succeed(None)
event.state_events = {
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
user_id="@alice:green",
room_id=room_id,
),
(RoomMemberEvent.TYPE, "@bob:red"): self._create_member(
user_id="@bob:red",
room_id=room_id,
),
(RoomMemberEvent.TYPE, target_user_id): event,
}
# Actual invocation
yield self.room_member_handler.change_membership(event)
self.state_handler.handle_new_event.assert_called_once_with(
event, self.snapshot,
)
self.federation.handle_new_event.assert_called_once_with(
event, self.snapshot,
)
self.assertEquals(
set(["blue", "red", "green"]),
set(["red", "green"]),
set(event.destinations)
)
@ -144,28 +158,19 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
room_id = "!foo:red"
user_id = "@bob:red"
user = self.hs.parse_userid(user_id)
target_user_id = "@bob:red"
content = {"membership": Membership.JOIN}
event = self.hs.get_event_factory().create_event(
etype=RoomMemberEvent.TYPE,
event = self._create_member(
user_id=user_id,
state_key=target_user_id,
room_id=room_id,
membership=Membership.JOIN,
content=content,
)
joined = ["red", "green"]
self.state_handler.handle_new_event.return_value = defer.succeed(True)
def get_joined(*args):
return defer.succeed(joined)
self.datastore.get_joined_hosts_for_room.side_effect = get_joined
store_id = "store_id_fooo"
self.datastore.persist_event.return_value = defer.succeed(store_id)
self.datastore.get_room.return_value = defer.succeed(1) # Not None.
@ -178,12 +183,17 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
join_signal_observer = Mock()
self.distributor.observe("user_joined_room", join_signal_observer)
event.state_events = {
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
user_id="@alice:green",
room_id=room_id,
),
(RoomMemberEvent.TYPE, user_id): event,
}
# Actual invocation
yield self.room_member_handler.change_membership(event)
self.state_handler.handle_new_event.assert_called_once_with(
event, self.snapshot
)
self.federation.handle_new_event.assert_called_once_with(
event, self.snapshot
)
@ -197,138 +207,32 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
event
)
self.notifier.on_new_room_event.assert_called_once_with(
event, extra_users=[user])
event, extra_users=[user]
)
join_signal_observer.assert_called_with(
user=user, room_id=room_id)
user=user, room_id=room_id
)
@defer.inlineCallbacks
def STALE_test_invite_join(self):
room_id = "foo"
user_id = "@bob:red"
target_user_id = "@bob:red"
content = {"membership": Membership.JOIN}
event = self.hs.get_event_factory().create_event(
def _create_member(self, user_id, room_id):
return self.hs.get_event_factory().create_event(
etype=RoomMemberEvent.TYPE,
user_id=user_id,
target_user_id=target_user_id,
state_key=user_id,
room_id=room_id,
membership=Membership.JOIN,
content=content,
content={"membership": Membership.JOIN},
)
joined = ["red", "blue", "green"]
self.state_handler.handle_new_event.return_value = defer.succeed(True)
self.datastore.get_joined_hosts_for_room.return_value = (
defer.succeed(joined)
)
store_id = "store_id_fooo"
self.datastore.store_room_member.return_value = defer.succeed(store_id)
self.datastore.get_room.return_value = defer.succeed(None)
prev_state = NonCallableMock(name="prev_state")
prev_state.membership = Membership.INVITE
prev_state.sender = "@foo:blue"
self.datastore.get_room_member.return_value = defer.succeed(prev_state)
# Actual invocation
yield self.room_member_handler.change_membership(event)
self.datastore.get_room_member.assert_called_once_with(
target_user_id, room_id
)
self.assertTrue(self.federation.handle_new_event.called)
args = self.federation.handle_new_event.call_args[0]
invite_join_event = args[0]
self.assertTrue(InviteJoinEvent.TYPE, invite_join_event.TYPE)
self.assertTrue("blue", invite_join_event.target_host)
self.assertTrue(room_id, invite_join_event.room_id)
self.assertTrue(user_id, invite_join_event.user_id)
self.assertFalse(hasattr(invite_join_event, "state_key"))
self.assertEquals(
set(["blue"]),
set(invite_join_event.destinations)
)
self.federation.get_state_for_room.assert_called_once_with(
"blue", room_id
)
self.assertFalse(self.datastore.store_room_member.called)
self.assertFalse(self.notifier.on_new_room_event.called)
self.assertFalse(self.state_handler.handle_new_event.called)
@defer.inlineCallbacks
def STALE_test_invite_join_public(self):
room_id = "#foo:blue"
user_id = "@bob:red"
target_user_id = "@bob:red"
content = {"membership": Membership.JOIN}
event = self.hs.get_event_factory().create_event(
etype=RoomMemberEvent.TYPE,
user_id=user_id,
target_user_id=target_user_id,
room_id=room_id,
membership=Membership.JOIN,
content=content,
)
joined = ["red", "blue", "green"]
self.state_handler.handle_new_event.return_value = defer.succeed(True)
self.datastore.get_joined_hosts_for_room.return_value = (
defer.succeed(joined)
)
store_id = "store_id_fooo"
self.datastore.store_room_member.return_value = defer.succeed(store_id)
self.datastore.get_room.return_value = defer.succeed(None)
prev_state = NonCallableMock(name="prev_state")
prev_state.membership = Membership.INVITE
prev_state.sender = "@foo:blue"
self.datastore.get_room_member.return_value = defer.succeed(prev_state)
# Actual invocation
yield self.room_member_handler.change_membership(event)
self.assertTrue(self.federation.handle_new_event.called)
args = self.federation.handle_new_event.call_args[0]
invite_join_event = args[0]
self.assertTrue(InviteJoinEvent.TYPE, invite_join_event.TYPE)
self.assertTrue("blue", invite_join_event.target_host)
self.assertTrue("foo", invite_join_event.room_id)
self.assertTrue(user_id, invite_join_event.user_id)
self.assertFalse(hasattr(invite_join_event, "state_key"))
self.assertEquals(
set(["blue"]),
set(invite_join_event.destinations)
)
self.federation.get_state_for_room.assert_called_once_with(
"blue", "foo"
)
self.assertFalse(self.datastore.store_room_member.called)
self.assertFalse(self.notifier.on_new_room_event.called)
self.assertFalse(self.state_handler.handle_new_event.called)
class RoomCreationTest(unittest.TestCase):
def setUp(self):
self.hostname = "red"
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer(
self.hostname,
db_pool=None,
@ -345,12 +249,14 @@ class RoomCreationTest(unittest.TestCase):
"room_member_handler",
"federation_handler",
]),
auth=NonCallableMock(spec_set=["check"]),
state_handler=NonCallableMock(spec_set=["handle_new_event"]),
auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
state_handler=NonCallableMock(spec_set=[
"annotate_state_groups",
]),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
config=self.mock_config,
)
self.federation = NonCallableMock(spec_set=[
@ -373,6 +279,11 @@ class RoomCreationTest(unittest.TestCase):
])
self.room_member_handler = self.handlers.room_member_handler
def annotate(event):
event.state_events = {}
return defer.succeed(None)
self.state_handler.annotate_state_groups.side_effect = annotate
def hosts(room):
return defer.succeed([])
self.datastore.get_joined_hosts_for_room.side_effect = hosts
@ -400,6 +311,6 @@ class RoomCreationTest(unittest.TestCase):
self.assertEquals(user_id, join_event.user_id)
self.assertEquals(user_id, join_event.state_key)
self.assertTrue(self.state_handler.handle_new_event.called)
self.assertTrue(self.state_handler.annotate_state_groups.called)
self.assertTrue(self.federation.handle_new_event.called)

View File

@ -40,6 +40,7 @@ def _expect_edu(destination, edu_type, content, origin="test"):
"content": content,
}
],
"pdu_failures": [],
}

View File

@ -25,10 +25,7 @@ import synapse.rest.room
from synapse.server import HomeServer
# python imports
import json
from ..utils import MockHttpResource, MemoryDataStore
from ..utils import MockHttpResource, SQLiteMemoryDbPool, MockKey
from .utils import RestTestCase
from mock import Mock, NonCallableMock
@ -49,7 +46,7 @@ class EventStreamPaginationApiTestCase(unittest.TestCase):
def tearDown(self):
pass
def test_long_poll(self):
def TODO_test_long_poll(self):
# stream from 'end' key, send (self+other) message, expect message.
# stream from 'END', send (self+other) message, expect message.
@ -64,7 +61,7 @@ class EventStreamPaginationApiTestCase(unittest.TestCase):
pass
def test_stream_forward(self):
def TODO_test_stream_forward(self):
# stream from START, expect injected items
# stream from 'start' key, expect same content
@ -80,14 +77,14 @@ class EventStreamPaginationApiTestCase(unittest.TestCase):
# returned as end key
pass
def test_limits(self):
def TODO_test_limits(self):
# stream from a key, expect limit_num items
# stream from START, expect limit_num items
pass
def test_range(self):
def TODO_test_range(self):
# stream from key to key, expect X items
# stream from key to END, expect X items
@ -97,7 +94,7 @@ class EventStreamPaginationApiTestCase(unittest.TestCase):
# stream from START to END, expect all items
pass
def test_direction(self):
def TODO_test_direction(self):
# stream from END to START and fwds, expect newest first
# stream from END to START and bwds, expect oldest first
@ -116,19 +113,20 @@ class EventStreamPermissionsTestCase(RestTestCase):
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
state_handler = Mock(spec=["handle_new_event"])
state_handler.handle_new_event.return_value = True
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer(
"test",
db_pool=None,
db_pool=db_pool,
http_client=None,
replication_layer=Mock(),
state_handler=state_handler,
datastore=MemoryDataStore(),
persistence_service=persistence_service,
clock=Mock(spec=[
"call_later",
@ -139,7 +137,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
config=self.mock_config,
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@ -148,6 +146,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
hs.get_clock().time_msec.return_value = 1000000
hs.get_clock().time.return_value = 1000
synapse.rest.register.register_servlets(hs, self.mock_resource)
synapse.rest.events.register_servlets(hs, self.mock_resource)
@ -172,12 +171,14 @@ class EventStreamPermissionsTestCase(RestTestCase):
def test_stream_basic_permissions(self):
# invalid token, expect 403
(code, response) = yield self.mock_resource.trigger_get(
"/events?access_token=%s" % ("invalid" + self.token))
"/events?access_token=%s" % ("invalid" + self.token, )
)
self.assertEquals(403, code, msg=str(response))
# valid token, expect content
(code, response) = yield self.mock_resource.trigger_get(
"/events?access_token=%s&timeout=0" % (self.token))
"/events?access_token=%s&timeout=0" % (self.token,)
)
self.assertEquals(200, code, msg=str(response))
self.assertTrue("chunk" in response)
self.assertTrue("start" in response)
@ -185,15 +186,23 @@ class EventStreamPermissionsTestCase(RestTestCase):
@defer.inlineCallbacks
def test_stream_room_permissions(self):
room_id = yield self.create_room_as(self.other_user,
tok=self.other_token)
room_id = yield self.create_room_as(
self.other_user,
tok=self.other_token
)
yield self.send(room_id, tok=self.other_token)
# invited to room (expect no content for room)
yield self.invite(room_id, src=self.other_user, targ=self.user_id,
tok=self.other_token)
yield self.invite(
room_id,
src=self.other_user,
targ=self.user_id,
tok=self.other_token
)
(code, response) = yield self.mock_resource.trigger_get(
"/events?access_token=%s&timeout=0" % (self.token))
"/events?access_token=%s&timeout=0" % (self.token,)
)
self.assertEquals(200, code, msg=str(response))
self.assertEquals(0, len(response["chunk"]))
@ -203,7 +212,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
# left to room (expect no content for room)
def test_stream_items(self):
def TODO_test_stream_items(self):
# new user, no content
# join room, expect 1 item (join)

View File

@ -18,9 +18,9 @@
from tests import unittest
from twisted.internet import defer
from mock import Mock
from mock import Mock, NonCallableMock
from ..utils import MockHttpResource
from ..utils import MockHttpResource, MockKey
from synapse.api.errors import SynapseError, AuthError
from synapse.server import HomeServer
@ -41,6 +41,9 @@ class ProfileTestCase(unittest.TestCase):
"set_avatar_url",
])
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer("test",
db_pool=None,
http_client=None,
@ -48,6 +51,7 @@ class ProfileTestCase(unittest.TestCase):
federation=Mock(),
replication_layer=Mock(),
datastore=None,
config=self.mock_config,
)
def _get_user_by_req(request=None):

View File

@ -23,11 +23,14 @@ from synapse.api.constants import Membership
from synapse.server import HomeServer
from tests import unittest
# python imports
import json
import urllib
import types
from ..utils import MockHttpResource, MemoryDataStore
from ..utils import MockHttpResource, SQLiteMemoryDbPool, MockKey
from .utils import RestTestCase
from mock import Mock, NonCallableMock
@ -44,24 +47,21 @@ class RoomPermissionsTestCase(RestTestCase):
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
state_handler = Mock(spec=["handle_new_event"])
state_handler.handle_new_event.return_value = True
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = []
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer(
"red",
db_pool=None,
db_pool=db_pool,
http_client=None,
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
config=self.mock_config,
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@ -76,6 +76,10 @@ class RoomPermissionsTestCase(RestTestCase):
}
hs.get_auth().get_user_by_token = _get_user_by_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
hs.get_datastore().insert_client_ip = _insert_client_ip
self.auth_user_id = self.rmcreator_id
synapse.rest.room.register_servlets(hs, self.mock_resource)
@ -147,38 +151,55 @@ class RoomPermissionsTestCase(RestTestCase):
@defer.inlineCallbacks
def test_send_message(self):
msg_content = '{"msgtype":"m.text","body":"hello"}'
send_msg_path = ("/rooms/%s/send/m.room.message/mid1" %
(self.created_rmid))
send_msg_path = (
"/rooms/%s/send/m.room.message/mid1" % (self.created_rmid,)
)
# send message in uncreated room, expect 403
(code, response) = yield self.mock_resource.trigger(
"PUT",
"/rooms/%s/send/m.room.message/mid2" %
(self.uncreated_rmid), msg_content)
"PUT",
"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
msg_content
)
self.assertEquals(403, code, msg=str(response))
# send message in created room not joined (no state), expect 403
(code, response) = yield self.mock_resource.trigger(
"PUT", send_msg_path, msg_content)
"PUT",
send_msg_path,
msg_content
)
self.assertEquals(403, code, msg=str(response))
# send message in created room and invited, expect 403
yield self.invite(room=self.created_rmid, src=self.rmcreator_id,
targ=self.user_id)
yield self.invite(
room=self.created_rmid,
src=self.rmcreator_id,
targ=self.user_id
)
(code, response) = yield self.mock_resource.trigger(
"PUT", send_msg_path, msg_content)
"PUT",
send_msg_path,
msg_content
)
self.assertEquals(403, code, msg=str(response))
# send message in created room and joined, expect 200
yield self.join(room=self.created_rmid, user=self.user_id)
(code, response) = yield self.mock_resource.trigger(
"PUT", send_msg_path, msg_content)
"PUT",
send_msg_path,
msg_content
)
self.assertEquals(200, code, msg=str(response))
# send message in created room and left, expect 403
yield self.leave(room=self.created_rmid, user=self.user_id)
(code, response) = yield self.mock_resource.trigger(
"PUT", send_msg_path, msg_content)
"PUT",
send_msg_path,
msg_content
)
self.assertEquals(403, code, msg=str(response))
@defer.inlineCallbacks
@ -215,9 +236,14 @@ class RoomPermissionsTestCase(RestTestCase):
# set/get topic in created PRIVATE room and joined, expect 200
yield self.join(room=self.created_rmid, user=self.user_id)
# Only room ops can set topic by default
self.auth_user_id = self.rmcreator_id
(code, response) = yield self.mock_resource.trigger(
"PUT", topic_path, topic_content)
self.assertEquals(200, code, msg=str(response))
self.auth_user_id = self.user_id
(code, response) = yield self.mock_resource.trigger_get(topic_path)
self.assertEquals(200, code, msg=str(response))
self.assert_dict(json.loads(topic_content), response)
@ -381,45 +407,55 @@ class RoomPermissionsTestCase(RestTestCase):
# set [invite/join/left] of self, set [invite/join/left] of other,
# expect all 403s
for usr in [self.user_id, self.rmcreator_id]:
yield self.change_membership(room=room, src=self.user_id,
targ=usr,
membership=Membership.INVITE,
expect_code=403)
yield self.change_membership(room=room, src=self.user_id,
targ=usr,
membership=Membership.JOIN,
expect_code=403)
yield self.change_membership(room=room, src=self.user_id,
targ=usr,
membership=Membership.LEAVE,
expect_code=403)
yield self.change_membership(
room=room,
src=self.user_id,
targ=usr,
membership=Membership.INVITE,
expect_code=403
)
yield self.change_membership(
room=room,
src=self.user_id,
targ=usr,
membership=Membership.JOIN,
expect_code=403
)
# It is always valid to LEAVE if you've already left (currently.)
yield self.change_membership(
room=room,
src=self.user_id,
targ=self.rmcreator_id,
membership=Membership.LEAVE,
expect_code=403
)
class RoomsMemberListTestCase(RestTestCase):
""" Tests /rooms/$room_id/members/list REST events."""
user_id = "@sid1:red"
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
state_handler = Mock(spec=["handle_new_event"])
state_handler.handle_new_event.return_value = True
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = []
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer(
"red",
db_pool=None,
db_pool=db_pool,
http_client=None,
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
config=self.mock_config,
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@ -436,6 +472,10 @@ class RoomsMemberListTestCase(RestTestCase):
}
hs.get_auth().get_user_by_token = _get_user_by_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
hs.get_datastore().insert_client_ip = _insert_client_ip
synapse.rest.room.register_servlets(hs, self.mock_resource)
def tearDown(self):
@ -487,28 +527,26 @@ class RoomsCreateTestCase(RestTestCase):
""" Tests /rooms and /rooms/$room_id REST events. """
user_id = "@sid1:red"
@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.auth_user_id = self.user_id
state_handler = Mock(spec=["handle_new_event"])
state_handler.handle_new_event.return_value = True
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = []
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer(
"red",
db_pool=None,
db_pool=db_pool,
http_client=None,
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
config=self.mock_config,
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@ -523,6 +561,10 @@ class RoomsCreateTestCase(RestTestCase):
}
hs.get_auth().get_user_by_token = _get_user_by_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
hs.get_datastore().insert_client_ip = _insert_client_ip
synapse.rest.room.register_servlets(hs, self.mock_resource)
def tearDown(self):
@ -592,24 +634,21 @@ class RoomTopicTestCase(RestTestCase):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.auth_user_id = self.user_id
state_handler = Mock(spec=["handle_new_event"])
state_handler.handle_new_event.return_value = True
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = []
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer(
"red",
db_pool=None,
db_pool=db_pool,
http_client=None,
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
config=self.mock_config,
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@ -622,13 +661,18 @@ class RoomTopicTestCase(RestTestCase):
"admin": False,
"device_id": None,
}
hs.get_auth().get_user_by_token = _get_user_by_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
hs.get_datastore().insert_client_ip = _insert_client_ip
synapse.rest.room.register_servlets(hs, self.mock_resource)
# create the room
self.room_id = yield self.create_room_as(self.user_id)
self.path = "/rooms/%s/state/m.room.topic" % self.room_id
self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,)
def tearDown(self):
pass
@ -706,24 +750,21 @@ class RoomMemberStateTestCase(RestTestCase):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.auth_user_id = self.user_id
state_handler = Mock(spec=["handle_new_event"])
state_handler.handle_new_event.return_value = True
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = []
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer(
"red",
db_pool=None,
db_pool=db_pool,
http_client=None,
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
config=self.mock_config,
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@ -736,13 +777,12 @@ class RoomMemberStateTestCase(RestTestCase):
"admin": False,
"device_id": None,
}
return {
"user": hs.parse_userid(self.auth_user_id),
"admin": False,
"device_id": None,
}
hs.get_auth().get_user_by_token = _get_user_by_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
hs.get_datastore().insert_client_ip = _insert_client_ip
synapse.rest.room.register_servlets(hs, self.mock_resource)
self.room_id = yield self.create_room_as(self.user_id)
@ -847,24 +887,21 @@ class RoomMessagesTestCase(RestTestCase):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.auth_user_id = self.user_id
state_handler = Mock(spec=["handle_new_event"])
state_handler.handle_new_event.return_value = True
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = []
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
hs = HomeServer(
"red",
db_pool=None,
db_pool=db_pool,
http_client=None,
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
config=self.mock_config,
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@ -879,6 +916,10 @@ class RoomMessagesTestCase(RestTestCase):
}
hs.get_auth().get_user_by_token = _get_user_by_token
def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
hs.get_datastore().insert_client_ip = _insert_client_ip
synapse.rest.room.register_servlets(hs, self.mock_resource)
self.room_id = yield self.create_room_as(self.user_id)

View File

@ -74,7 +74,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_select_one_1col(self):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = ("Value",)
self.mock_txn.fetchall.return_value = [("Value",)]
value = yield self.datastore._simple_select_one_onecol(
table="tablename",

View File

@ -61,6 +61,7 @@ class RedactionTestCase(unittest.TestCase):
membership=membership,
content={"membership": membership},
depth=self.depth,
prev_events=[],
)
event.content.update(extra_content)
@ -68,6 +69,11 @@ class RedactionTestCase(unittest.TestCase):
if prev_state:
event.prev_state = prev_state
event.state_events = None
event.hashes = {}
event.prev_state = []
event.auth_events = []
# Have to create a join event using the eventfactory
yield self.store.persist_event(
event
@ -85,8 +91,13 @@ class RedactionTestCase(unittest.TestCase):
room_id=room.to_string(),
content={"body": body, "msgtype": u"message"},
depth=self.depth,
prev_events=[],
)
event.state_events = None
event.hashes = {}
event.auth_events = []
yield self.store.persist_event(
event
)
@ -102,8 +113,13 @@ class RedactionTestCase(unittest.TestCase):
content={"reason": reason},
depth=self.depth,
redacts=event_id,
prev_events=[],
)
event.state_events = None
event.hashes = {}
event.auth_events = []
yield self.store.persist_event(
event
)

View File

@ -127,7 +127,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
def test_room_name(self):
def STALE_test_room_name(self):
name = u"A-Room-Name"
yield self.inject_room_event(
@ -150,7 +150,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
def test_room_name(self):
def STALE_test_room_topic(self):
topic = u"A place for things"
yield self.inject_room_event(

View File

@ -51,16 +51,24 @@ class RoomMemberStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership):
# Have to create a join event using the eventfactory
event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
user_id=user.to_string(),
state_key=user.to_string(),
room_id=room.to_string(),
membership=membership,
content={"membership": membership},
depth=1,
prev_events=[],
)
event.state_events = None
event.hashes = {}
event.prev_state = {}
event.auth_events = {}
yield self.store.persist_event(
self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
user_id=user.to_string(),
state_key=user.to_string(),
room_id=room.to_string(),
membership=membership,
content={"membership": membership},
depth=1,
)
event
)
@defer.inlineCallbacks

View File

@ -48,7 +48,7 @@ class StreamStoreTestCase(unittest.TestCase):
self.depth = 1
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership, prev_state=None):
def inject_room_member(self, room, user, membership, replaces_state=None):
self.depth += 1
event = self.event_factory.create_event(
@ -59,10 +59,17 @@ class StreamStoreTestCase(unittest.TestCase):
membership=membership,
content={"membership": membership},
depth=self.depth,
prev_events=[],
)
if prev_state:
event.prev_state = prev_state
event.state_events = None
event.hashes = {}
event.prev_state = []
event.auth_events = []
if replaces_state:
event.prev_state = [(replaces_state, "hash")]
event.replaces_state = replaces_state
# Have to create a join event using the eventfactory
yield self.store.persist_event(
@ -75,15 +82,22 @@ class StreamStoreTestCase(unittest.TestCase):
def inject_message(self, room, user, body):
self.depth += 1
event = self.event_factory.create_event(
etype=MessageEvent.TYPE,
user_id=user.to_string(),
room_id=room.to_string(),
content={"body": body, "msgtype": u"message"},
depth=self.depth,
prev_events=[],
)
event.state_events = None
event.hashes = {}
event.auth_events = []
# Have to create a join event using the eventfactory
yield self.store.persist_event(
self.event_factory.create_event(
etype=MessageEvent.TYPE,
user_id=user.to_string(),
room_id=room.to_string(),
content={"body": body, "msgtype": u"message"},
depth=self.depth,
)
event
)
@defer.inlineCallbacks
@ -206,7 +220,7 @@ class StreamStoreTestCase(unittest.TestCase):
event2 = yield self.inject_room_member(
self.room1, self.u_alice, Membership.JOIN,
prev_state=event1.event_id,
replaces_state=event1.event_id,
)
end = yield self.store.get_room_events_max_id()
@ -223,4 +237,7 @@ class StreamStoreTestCase(unittest.TestCase):
event = results[0]
self.assertTrue(hasattr(event, "prev_content"), msg="No prev_content key")
self.assertTrue(
hasattr(event, "prev_content"),
msg="No prev_content key"
)

View File

@ -15,599 +15,258 @@
from tests import unittest
from twisted.internet import defer
from twisted.python.log import PythonLoggingObserver
from synapse.state import StateHandler
from synapse.storage.pdu import PduEntry
from synapse.federation.pdu_codec import encode_event_id
from synapse.federation.units import Pdu
from collections import namedtuple
from mock import Mock
import mock
ReturnType = namedtuple(
"StateReturnType", ["new_branch", "current_branch"]
)
def _gen_get_power_level(power_level_list):
def get_power_level(room_id, user_id):
return defer.succeed(power_level_list.get(user_id, None))
return get_power_level
class StateTestCase(unittest.TestCase):
def setUp(self):
self.persistence = Mock(spec=[
"get_unresolved_state_tree",
"update_current_state",
"get_latest_pdus_in_context",
"get_current_state_pdu",
"get_pdu",
"get_power_level",
])
self.replication = Mock(spec=["get_pdu"])
hs = Mock(spec=["get_datastore", "get_replication_layer"])
hs.get_datastore.return_value = self.persistence
hs.get_replication_layer.return_value = self.replication
hs.hostname = "bob.com"
self.store = Mock(
spec_set=[
"get_state_groups",
]
)
hs = Mock(spec=["get_datastore"])
hs.get_datastore.return_value = self.store
self.state = StateHandler(hs)
self.event_id = 0
@defer.inlineCallbacks
def test_new_state_key(self):
# We've never seen anything for this state before
new_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u")
def test_annotate_with_old_message(self):
event = self.create_event(type="test_message", name="event")
self.persistence.get_power_level.side_effect = _gen_get_power_level({})
self.persistence.get_unresolved_state_tree.return_value = (
(ReturnType([new_pdu], []), None)
)
is_new = yield self.state.handle_new_state(new_pdu)
self.assertTrue(is_new)
self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)
self.assertEqual(1, self.persistence.update_current_state.call_count)
self.assertFalse(self.replication.get_pdu.called)
@defer.inlineCallbacks
def test_direct_overwrite(self):
# We do a direct overwriting of the old state, i.e., the new state
# points to the old state.
old_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1")
new_pdu = new_fake_pdu("B", "test", "mem", "x", "A", "u2")
self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 5,
})
self.persistence.get_unresolved_state_tree.return_value = (
(ReturnType([new_pdu, old_pdu], [old_pdu]), None)
)
is_new = yield self.state.handle_new_state(new_pdu)
self.assertTrue(is_new)
self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)
self.assertEqual(1, self.persistence.update_current_state.call_count)
self.assertFalse(self.replication.get_pdu.called)
@defer.inlineCallbacks
def test_overwrite(self):
old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", "A", "u2")
new_pdu = new_fake_pdu("C", "test", "mem", "x", "B", "u3")
self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 5,
"u3": 0,
})
self.persistence.get_unresolved_state_tree.return_value = (
(ReturnType([new_pdu, old_pdu_2, old_pdu_1], [old_pdu_1]), None)
)
is_new = yield self.state.handle_new_state(new_pdu)
self.assertTrue(is_new)
self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)
self.assertEqual(1, self.persistence.update_current_state.call_count)
self.assertFalse(self.replication.get_pdu.called)
@defer.inlineCallbacks
def test_power_level_fail(self):
# We try to update the state based on an outdated state, and have a
# too low power level.
old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")
self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 10,
"u3": 5,
})
self.persistence.get_unresolved_state_tree.return_value = (
(ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None)
)
is_new = yield self.state.handle_new_state(new_pdu)
self.assertFalse(is_new)
self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)
self.assertEqual(0, self.persistence.update_current_state.call_count)
self.assertFalse(self.replication.get_pdu.called)
@defer.inlineCallbacks
def test_power_level_succeed(self):
# We try to update the state based on an outdated state, but have
# sufficient power level to force the update.
old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")
self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 10,
"u3": 15,
})
self.persistence.get_unresolved_state_tree.return_value = (
(ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None)
)
is_new = yield self.state.handle_new_state(new_pdu)
self.assertTrue(is_new)
self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)
self.assertEqual(1, self.persistence.update_current_state.call_count)
self.assertFalse(self.replication.get_pdu.called)
@defer.inlineCallbacks
def test_power_level_equal_same_len(self):
# We try to update the state based on an outdated state, the power
# levels are the same and so are the branch lengths
old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")
self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 10,
"u3": 10,
})
self.persistence.get_unresolved_state_tree.return_value = (
(ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None)
)
is_new = yield self.state.handle_new_state(new_pdu)
self.assertTrue(is_new)
self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)
self.assertEqual(1, self.persistence.update_current_state.call_count)
self.assertFalse(self.replication.get_pdu.called)
@defer.inlineCallbacks
def test_power_level_equal_diff_len(self):
# We try to update the state based on an outdated state, the power
# levels are the same but the branch length of the new one is longer.
old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
old_pdu_3 = new_fake_pdu("C", "test", "mem", "x", "A", "u3")
new_pdu = new_fake_pdu("D", "test", "mem", "x", "C", "u4")
self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 10,
"u3": 10,
"u4": 10,
})
self.persistence.get_unresolved_state_tree.return_value = (
(
ReturnType(
[new_pdu, old_pdu_3, old_pdu_1],
[old_pdu_2, old_pdu_1]
),
None
)
)
is_new = yield self.state.handle_new_state(new_pdu)
self.assertTrue(is_new)
self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)
self.assertEqual(1, self.persistence.update_current_state.call_count)
self.assertFalse(self.replication.get_pdu.called)
@defer.inlineCallbacks
def test_missing_pdu(self):
# We try to update state against a PDU we haven't yet seen,
# triggering a get_pdu request
# The pdu we haven't seen
old_pdu_1 = new_fake_pdu(
"A", "test", "mem", "x", None, "u1", depth=0
)
old_pdu_2 = new_fake_pdu(
"B", "test", "mem", "x", "A", "u2", depth=1
)
new_pdu = new_fake_pdu(
"C", "test", "mem", "x", "A", "u3", depth=2
)
self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 10,
"u3": 20,
})
# The return_value of `get_unresolved_state_tree`, which changes after
# the call to get_pdu
tree_to_return = [(ReturnType([new_pdu], [old_pdu_2]), 0)]
def return_tree(p):
return tree_to_return[0]
def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
tree_to_return[0] = (
ReturnType(
[new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]
),
None
)
return defer.succeed(None)
self.persistence.get_unresolved_state_tree.side_effect = return_tree
self.replication.get_pdu.side_effect = set_return_tree
self.persistence.get_pdu.return_value = None
is_new = yield self.state.handle_new_state(new_pdu)
self.assertTrue(is_new)
self.replication.get_pdu.assert_called_with(
destination=new_pdu.origin,
pdu_origin=old_pdu_1.origin,
pdu_id=old_pdu_1.pdu_id,
outlier=True
)
self.persistence.get_unresolved_state_tree.assert_called_with(
new_pdu
)
self.assertEquals(
2, self.persistence.get_unresolved_state_tree.call_count
)
self.assertEqual(1, self.persistence.update_current_state.call_count)
@defer.inlineCallbacks
def test_missing_pdu_depth_1(self):
# We try to update state against a PDU we haven't yet seen,
# triggering a get_pdu request
# The pdu we haven't seen
old_pdu_1 = new_fake_pdu(
"A", "test", "mem", "x", None, "u1", depth=0
)
old_pdu_2 = new_fake_pdu(
"B", "test", "mem", "x", "A", "u2", depth=2
)
old_pdu_3 = new_fake_pdu(
"C", "test", "mem", "x", "B", "u3", depth=3
)
new_pdu = new_fake_pdu(
"D", "test", "mem", "x", "A", "u4", depth=4
)
self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 10,
"u3": 10,
"u4": 20,
})
# The return_value of `get_unresolved_state_tree`, which changes after
# the call to get_pdu
tree_to_return = [
(
ReturnType([new_pdu], [old_pdu_3]),
0
),
(
ReturnType(
[new_pdu, old_pdu_1], [old_pdu_3]
),
1
),
(
ReturnType(
[new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1]
),
None
),
old_state = [
self.create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""),
]
to_return = [0]
yield self.state.annotate_state_groups(event, old_state=old_state)
def return_tree(p):
return tree_to_return[to_return[0]]
for k, v in event.old_state_events.items():
type, state_key = k
self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key)
def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
to_return[0] += 1
return defer.succeed(None)
self.assertEqual(set(old_state), set(event.old_state_events.values()))
self.assertDictEqual(event.old_state_events, event.state_events)
self.persistence.get_unresolved_state_tree.side_effect = return_tree
self.replication.get_pdu.side_effect = set_return_tree
self.persistence.get_pdu.return_value = None
is_new = yield self.state.handle_new_state(new_pdu)
self.assertTrue(is_new)
self.assertEqual(2, self.replication.get_pdu.call_count)
self.replication.get_pdu.assert_has_calls(
[
mock.call(
destination=new_pdu.origin,
pdu_origin=old_pdu_1.origin,
pdu_id=old_pdu_1.pdu_id,
outlier=True
),
mock.call(
destination=old_pdu_3.origin,
pdu_origin=old_pdu_2.origin,
pdu_id=old_pdu_2.pdu_id,
outlier=True
),
]
)
self.persistence.get_unresolved_state_tree.assert_called_with(
new_pdu
)
self.assertEquals(
3, self.persistence.get_unresolved_state_tree.call_count
)
self.assertEqual(1, self.persistence.update_current_state.call_count)
self.assertIsNone(event.state_group)
@defer.inlineCallbacks
def test_missing_pdu_depth_2(self):
# We try to update state against a PDU we haven't yet seen,
# triggering a get_pdu request
def test_annotate_with_old_state(self):
event = self.create_event(type="state", state_key="", name="event")
# The pdu we haven't seen
old_pdu_1 = new_fake_pdu(
"A", "test", "mem", "x", None, "u1", depth=0
)
old_pdu_2 = new_fake_pdu(
"B", "test", "mem", "x", "A", "u2", depth=2
)
old_pdu_3 = new_fake_pdu(
"C", "test", "mem", "x", "B", "u3", depth=3
)
new_pdu = new_fake_pdu(
"D", "test", "mem", "x", "A", "u4", depth=1
)
self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 10,
"u3": 10,
"u4": 20,
})
# The return_value of `get_unresolved_state_tree`, which changes after
# the call to get_pdu
tree_to_return = [
(
ReturnType([new_pdu], [old_pdu_3]),
1,
),
(
ReturnType(
[new_pdu], [old_pdu_3, old_pdu_2]
),
0,
),
(
ReturnType(
[new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1]
),
None
),
old_state = [
self.create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""),
]
to_return = [0]
yield self.state.annotate_state_groups(event, old_state=old_state)
def return_tree(p):
return tree_to_return[to_return[0]]
def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
to_return[0] += 1
return defer.succeed(None)
self.persistence.get_unresolved_state_tree.side_effect = return_tree
self.replication.get_pdu.side_effect = set_return_tree
self.persistence.get_pdu.return_value = None
is_new = yield self.state.handle_new_state(new_pdu)
self.assertTrue(is_new)
self.assertEqual(2, self.replication.get_pdu.call_count)
self.replication.get_pdu.assert_has_calls(
[
mock.call(
destination=old_pdu_3.origin,
pdu_origin=old_pdu_2.origin,
pdu_id=old_pdu_2.pdu_id,
outlier=True
),
mock.call(
destination=new_pdu.origin,
pdu_origin=old_pdu_1.origin,
pdu_id=old_pdu_1.pdu_id,
outlier=True
),
]
)
self.persistence.get_unresolved_state_tree.assert_called_with(
new_pdu
)
self.assertEquals(
3, self.persistence.get_unresolved_state_tree.call_count
)
self.assertEqual(1, self.persistence.update_current_state.call_count)
@defer.inlineCallbacks
def test_no_common_ancestor(self):
# We do a direct overwriting of the old state, i.e., the new state
# points to the old state.
old_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1")
new_pdu = new_fake_pdu("B", "test", "mem", "x", None, "u2")
self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 5,
"u2": 10,
})
self.persistence.get_unresolved_state_tree.return_value = (
(ReturnType([new_pdu], [old_pdu]), None)
)
is_new = yield self.state.handle_new_state(new_pdu)
self.assertTrue(is_new)
self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)
self.assertEqual(1, self.persistence.update_current_state.call_count)
self.assertFalse(self.replication.get_pdu.called)
@defer.inlineCallbacks
def test_new_event(self):
event = Mock()
event.event_id = "12123123@test"
state_pdu = new_fake_pdu("C", "test", "mem", "x", "A", 20)
snapshot = Mock()
snapshot.prev_state_pdu = state_pdu
event_id = "pdu_id@origin.com"
def fill_out_prev_events(event):
event.prev_events = [event_id]
event.depth = 6
snapshot.fill_out_prev_events = fill_out_prev_events
yield self.state.handle_new_event(event, snapshot)
self.assertLess(5, event.depth)
self.assertEquals(1, len(event.prev_events))
prev_id = event.prev_events[0]
self.assertEqual(event_id, prev_id)
for k, v in event.old_state_events.items():
type, state_key = k
self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key)
self.assertEqual(
encode_event_id(state_pdu.pdu_id, state_pdu.origin),
event.prev_state
set(old_state + [event]),
set(event.old_state_events.values())
)
self.assertDictEqual(event.old_state_events, event.state_events)
def new_fake_pdu(pdu_id, context, pdu_type, state_key, prev_state_id,
user_id, depth=0):
new_pdu = Pdu(
pdu_id=pdu_id,
pdu_type=pdu_type,
state_key=state_key,
user_id=user_id,
prev_state_id=prev_state_id,
origin="example.com",
context="context",
origin_server_ts=1405353060021,
depth=depth,
content_json="{}",
unrecognized_keys="{}",
outlier=True,
is_state=True,
prev_state_origin="example.com",
have_processed=True,
content={},
)
self.assertIsNone(event.state_group)
return new_pdu
@defer.inlineCallbacks
def test_trivial_annotate_message(self):
event = self.create_event(type="test_message", name="event")
event.prev_events = []
old_state = [
self.create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""),
]
group_name = "group_name_1"
self.store.get_state_groups.return_value = {
group_name: old_state,
}
yield self.state.annotate_state_groups(event)
for k, v in event.old_state_events.items():
type, state_key = k
self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key)
self.assertEqual(
set([e.event_id for e in old_state]),
set([e.event_id for e in event.old_state_events.values()])
)
self.assertDictEqual(
{
k: v.event_id
for k, v in event.old_state_events.items()
},
{
k: v.event_id
for k, v in event.state_events.items()
}
)
self.assertEqual(group_name, event.state_group)
@defer.inlineCallbacks
def test_trivial_annotate_state(self):
event = self.create_event(type="state", state_key="", name="event")
event.prev_events = []
old_state = [
self.create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""),
]
group_name = "group_name_1"
self.store.get_state_groups.return_value = {
group_name: old_state,
}
yield self.state.annotate_state_groups(event)
for k, v in event.old_state_events.items():
type, state_key = k
self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key)
self.assertEqual(
set([e.event_id for e in old_state]),
set([e.event_id for e in event.old_state_events.values()])
)
self.assertEqual(
set([e.event_id for e in old_state] + [event.event_id]),
set([e.event_id for e in event.state_events.values()])
)
new_state = {
k: v.event_id
for k, v in event.state_events.items()
}
old_state = {
k: v.event_id
for k, v in event.old_state_events.items()
}
old_state[(event.type, event.state_key)] = event.event_id
self.assertDictEqual(
old_state,
new_state
)
self.assertIsNone(event.state_group)
@defer.inlineCallbacks
def test_resolve_message_conflict(self):
event = self.create_event(type="test_message", name="event")
event.prev_events = []
old_state_1 = [
self.create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""),
]
old_state_2 = [
self.create_event(type="test1", state_key="1"),
self.create_event(type="test3", state_key="2"),
self.create_event(type="test4", state_key=""),
]
group_name_1 = "group_name_1"
group_name_2 = "group_name_2"
self.store.get_state_groups.return_value = {
group_name_1: old_state_1,
group_name_2: old_state_2,
}
yield self.state.annotate_state_groups(event)
self.assertEqual(len(event.old_state_events), 5)
self.assertEqual(
set([e.event_id for e in event.state_events.values()]),
set([e.event_id for e in event.old_state_events.values()])
)
self.assertIsNone(event.state_group)
@defer.inlineCallbacks
def test_resolve_state_conflict(self):
event = self.create_event(type="test4", state_key="", name="event")
event.prev_events = []
old_state_1 = [
self.create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""),
]
old_state_2 = [
self.create_event(type="test1", state_key="1"),
self.create_event(type="test3", state_key="2"),
self.create_event(type="test4", state_key=""),
]
group_name_1 = "group_name_1"
group_name_2 = "group_name_2"
self.store.get_state_groups.return_value = {
group_name_1: old_state_1,
group_name_2: old_state_2,
}
yield self.state.annotate_state_groups(event)
self.assertEqual(len(event.old_state_events), 5)
expected_new = event.old_state_events
expected_new[(event.type, event.state_key)] = event
self.assertEqual(
set([e.event_id for e in expected_new.values()]),
set([e.event_id for e in event.state_events.values()]),
)
self.assertIsNone(event.state_group)
def create_event(self, name=None, type=None, state_key=None):
self.event_id += 1
event_id = str(self.event_id)
if not name:
if state_key is not None:
name = "<%s-%s>" % (type, state_key)
else:
name = "<%s>" % (type, )
event = Mock(name=name, spec=[])
event.type = type
if state_key is not None:
event.state_key = state_key
event.event_id = event_id
event.user_id = "@user_id:example.com"
event.room_id = "!room_id:example.com"
return event

View File

@ -118,13 +118,14 @@ class MockHttpResource(HttpServer):
class MockKey(object):
alg = "mock_alg"
version = "mock_version"
signature = b"\x9a\x87$"
@property
def verify_key(self):
return self
def sign(self, message):
return b"\x9a\x87$"
return self
def verify(self, message, sig):
assert sig == b"\x9a\x87$"