Merge pull request #12 from matrix-org/federation_authorization

Federation authorization

commit a8ceeec0fd
demo/start.sh:
@@ -32,7 +32,7 @@ for port in 8080 8081 8082; do
         -D --pid-file "$DIR/$port.pid" \
         --manhole $((port + 1000)) \
         --tls-dh-params-path "demo/demo.tls.dh" \
-        $PARAMS
+        $PARAMS $SYNAPSE_PARAMS

     python -m synapse.app.homeserver \
         --config-path "demo/etc/$port.config" \
Signing JSON documentation:
@@ -1,13 +1,13 @@
 Signing JSON
 ============

-JSON is signed by encoding the JSON object without ``signatures`` or ``meta``
+JSON is signed by encoding the JSON object without ``signatures`` or ``unsigned``
 keys using a canonical encoding. The JSON bytes are then signed using the
 signature algorithm and the signature encoded using base64 with the padding
 stripped. The resulting base64 signature is added to an object under the
 *signing key identifier* which is added to the ``signatures`` object under the
 name of the server signing it which is added back to the original JSON object
-along with the ``meta`` object.
+along with the ``unsigned`` object.

 The *signing key identifier* is the concatenation of the *signing algorithm*
 and a *key version*. The *signing algorithm* identifies the algorithm used to
@@ -15,8 +15,8 @@ sign the JSON. The currently supported value for *signing algorithm* is
 ``ed25519`` as implemented by NACL (http://nacl.cr.yp.to/). The *key version*
 is used to distinguish between different signing keys used by the same entity.

-The ``meta`` object and the ``signatures`` object are not covered by the
-signature. Therefore intermediate servers can add metadata such as time stamps
+The ``unsigned`` object and the ``signatures`` object are not covered by the
+signature. Therefore intermediate servers can add unsigned data such as time stamps
 and additional signatures.
@@ -27,7 +27,7 @@ and additional signatures.
     "signing_keys": {
         "ed25519:1": "XSl0kuyvrXNj6A+7/tkrB9sxSbRi08Of5uRhxOqZtEQ"
     },
-    "meta": {
+    "unsigned": {
         "retrieved_ts_ms": 922834800000
     },
     "signatures": {
@@ -41,7 +41,7 @@ and additional signatures.

     def sign_json(json_object, signing_key, signing_name):
         signatures = json_object.pop("signatures", {})
-        meta = json_object.pop("meta", None)
+        unsigned = json_object.pop("unsigned", None)

         signed = signing_key.sign(encode_canonical_json(json_object))
         signature_base64 = encode_base64(signed.signature)
@@ -50,8 +50,8 @@ and additional signatures.
         signatures.setdefault(signing_name, {})[key_id] = signature_base64

         json_object["signatures"] = signatures
-        if meta is not None:
-            json_object["meta"] = meta
+        if unsigned is not None:
+            json_object["unsigned"] = unsigned

         return json_object
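For comparison, the signing flow described above can be reproduced with PyNaCl
and the standard library alone. This is a rough sketch rather than syutil:
canonical_json() only approximates encode_canonical_json, and the server name
"example.com" and key version "1" are invented placeholders::

    import base64
    import json

    from nacl.signing import SigningKey

    def canonical_json(json_object):
        # Sorted keys, no insignificant whitespace: a stand-in for
        # syutil's encode_canonical_json.
        return json.dumps(json_object, sort_keys=True, separators=(",", ":"))

    def sign_json_sketch(json_object, signing_name, signing_key, key_version="1"):
        signatures = json_object.pop("signatures", {})
        unsigned = json_object.pop("unsigned", None)

        # Sign the canonical bytes, then base64-encode with the padding stripped.
        sig = signing_key.sign(canonical_json(json_object).encode("utf-8")).signature
        sig_base64 = base64.b64encode(sig).decode("ascii").rstrip("=")

        key_id = "ed25519:%s" % (key_version,)
        signatures.setdefault(signing_name, {})[key_id] = sig_base64

        json_object["signatures"] = signatures
        if unsigned is not None:
            json_object["unsigned"] = unsigned
        return json_object

    signed = sign_json_sketch({"type": "example"}, "example.com", SigningKey.generate())
    print(signed["signatures"]["example.com"].keys())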
scripts/check_event_hash.py (new file, 47 lines):
@@ -0,0 +1,47 @@
+from synapse.crypto.event_signing import *
+from syutil.base64util import encode_base64
+
+import argparse
+import hashlib
+import logging  # added: needed by logging.basicConfig() below, missing in the original
+import sys
+import json
+
+
+class dictobj(dict):
+    def __init__(self, *args, **kargs):
+        dict.__init__(self, *args, **kargs)
+        self.__dict__ = self
+
+    def get_dict(self):
+        return dict(self)
+
+    def get_full_dict(self):
+        return dict(self)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'),
+                        default=sys.stdin)
+    args = parser.parse_args()
+    logging.basicConfig()
+
+    event_json = dictobj(json.load(args.input_json))
+
+    algorithms = {
+        "sha256": hashlib.sha256,
+    }
+
+    for alg_name in event_json.hashes:
+        if check_event_content_hash(event_json, algorithms[alg_name]):
+            print "PASS content hash %s" % (alg_name,)
+        else:
+            print "FAIL content hash %s" % (alg_name,)
+
+    for algorithm in algorithms.values():
+        name, h_bytes = compute_event_reference_hash(event_json, algorithm)
+        print "Reference hash %s: %s" % (name, encode_base64(h_bytes))
+
+
+if __name__ == "__main__":
+    main()
scripts/check_signature.py (new file, 73 lines):
@@ -0,0 +1,73 @@
+
+from syutil.crypto.jsonsign import verify_signed_json
+from syutil.crypto.signing_key import (
+    decode_verify_key_bytes, write_signing_keys
+)
+from syutil.base64util import decode_base64
+
+import urllib2
+import json
+import sys
+import dns.resolver
+import pprint
+import argparse
+import logging
+
+
+def get_targets(server_name):
+    if ":" in server_name:
+        target, port = server_name.split(":")
+        yield (target, int(port))
+        return
+    try:
+        answers = dns.resolver.query("_matrix._tcp." + server_name, "SRV")
+        for srv in answers:
+            yield (srv.target, srv.port)
+    except dns.resolver.NXDOMAIN:
+        yield (server_name, 8480)
+
+
+def get_server_keys(server_name, target, port):
+    url = "https://%s:%i/_matrix/key/v1" % (target, port)
+    keys = json.load(urllib2.urlopen(url))
+    verify_keys = {}
+    for key_id, key_base64 in keys["verify_keys"].items():
+        verify_key = decode_verify_key_bytes(key_id, decode_base64(key_base64))
+        verify_signed_json(keys, server_name, verify_key)
+        verify_keys[key_id] = verify_key
+    return verify_keys
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("signature_name")
+    parser.add_argument("input_json", nargs="?", type=argparse.FileType('r'),
+                        default=sys.stdin)
+
+    args = parser.parse_args()
+    logging.basicConfig()
+
+    server_name = args.signature_name
+    keys = {}
+    for target, port in get_targets(server_name):
+        try:
+            keys = get_server_keys(server_name, target, port)
+            print "Using keys from https://%s:%s/_matrix/key/v1" % (target, port)
+            write_signing_keys(sys.stdout, keys.values())
+            break
+        except:
+            logging.exception("Error talking to %s:%s", target, port)
+
+    json_to_check = json.load(args.input_json)
+    print "Checking JSON:"
+    for key_id in json_to_check["signatures"][args.signature_name]:
+        try:
+            key = keys[key_id]
+            verify_signed_json(json_to_check, args.signature_name, key)
+            print "PASS %s" % (key_id,)
+        except:
+            logging.exception("Check for key %s failed" % (key_id,))
+            print "FAIL %s" % (key_id,)
+
+
+if __name__ == '__main__':
+    main()
scripts/hash_history.py (new file, 69 lines):
@@ -0,0 +1,69 @@
+from synapse.storage.pdu import PduStore
+from synapse.storage.signatures import SignatureStore
+from synapse.storage._base import SQLBaseStore
+from synapse.federation.units import Pdu
+from synapse.crypto.event_signing import (
+    add_event_pdu_content_hash, compute_pdu_event_reference_hash
+)
+from synapse.api.events.utils import prune_pdu
+from syutil.base64util import encode_base64, decode_base64
+from syutil.jsonutil import encode_canonical_json
+
+import sqlite3
+import sys
+
+
+class Store(object):
+    _get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"]
+    _get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"]
+    _get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"]
+    _get_pdu_origin_signatures_txn = SignatureStore.__dict__["_get_pdu_origin_signatures_txn"]
+    _store_pdu_content_hash_txn = SignatureStore.__dict__["_store_pdu_content_hash_txn"]
+    _store_pdu_reference_hash_txn = SignatureStore.__dict__["_store_pdu_reference_hash_txn"]
+    _store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"]
+    _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
+
+
+store = Store()
+
+
+def select_pdus(cursor):
+    cursor.execute(
+        "SELECT pdu_id, origin FROM pdus ORDER BY depth ASC"
+    )
+
+    ids = cursor.fetchall()
+
+    pdu_tuples = store._get_pdu_tuples(cursor, ids)
+
+    pdus = [Pdu.from_pdu_tuple(p) for p in pdu_tuples]
+
+    reference_hashes = {}
+
+    for pdu in pdus:
+        try:
+            if pdu.prev_pdus:
+                print "PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus
+                for pdu_id, origin, hashes in pdu.prev_pdus:
+                    ref_alg, ref_hsh = reference_hashes[(pdu_id, origin)]
+                    hashes[ref_alg] = encode_base64(ref_hsh)
+                    store._store_prev_pdu_hash_txn(cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh)
+                print "SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus
+            pdu = add_event_pdu_content_hash(pdu)
+            ref_alg, ref_hsh = compute_pdu_event_reference_hash(pdu)
+            reference_hashes[(pdu.pdu_id, pdu.origin)] = (ref_alg, ref_hsh)
+            store._store_pdu_reference_hash_txn(cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh)
+
+            for alg, hsh_base64 in pdu.hashes.items():
+                print alg, hsh_base64
+                store._store_pdu_content_hash_txn(cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64))
+
+        except:
+            print "FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus
+
+
+def main():
+    conn = sqlite3.connect(sys.argv[1])
+    cursor = conn.cursor()
+    select_pdus(cursor)
+    conn.commit()
+
+
+if __name__ == '__main__':
+    main()
synapse/api/auth.py:
@@ -21,8 +21,10 @@ from synapse.api.constants import Membership, JoinRules
 from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
 from synapse.api.events.room import (
     RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent,
+    RoomJoinRulesEvent, RoomCreateEvent,
 )
 from synapse.util.logutils import log_function
+from syutil.base64util import encode_base64

 import logging

@@ -35,8 +37,7 @@ class Auth(object):
         self.hs = hs
         self.store = hs.get_datastore()

-    @defer.inlineCallbacks
-    def check(self, event, snapshot, raises=False):
+    def check(self, event, raises=False):
         """ Checks if this event is correctly authed.

         Returns:
@@ -47,43 +48,51 @@ class Auth(object):
         """
         try:
             if hasattr(event, "room_id"):
-                is_state = hasattr(event, "state_key")
+                if event.old_state_events is None:
+                    # Oh, we don't know what the state of the room was, so we
+                    # are trusting that this is allowed (at least for now)
+                    logger.warn("Trusting event: %s", event.event_id)
+                    return True
+
+                if hasattr(event, "outlier") and event.outlier is True:
+                    # TODO (erikj): Auth for outliers is done differently.
+                    return True

                 if event.type == RoomCreateEvent.TYPE:
                     # FIXME
                     return True

                 if event.type == RoomMemberEvent.TYPE:
-                    yield self._can_replace_state(event)
-                    allowed = yield self.is_membership_change_allowed(event)
-                    defer.returnValue(allowed)
-                    return
-
-                self._check_joined_room(
-                    member=snapshot.membership_state,
-                    user_id=snapshot.user_id,
-                    room_id=snapshot.room_id,
-                )
-
-                if is_state:
-                    # TODO (erikj): This really only should be called for *new*
-                    # state
-                    yield self._can_add_state(event)
-                    yield self._can_replace_state(event)
-                else:
-                    yield self._can_send_event(event)
+                    allowed = self.is_membership_change_allowed(event)
+                    if allowed:
+                        logger.debug("Allowing! %s", event)
+                    else:
+                        logger.debug("Denying! %s", event)
+                    return allowed
+
+                self.check_event_sender_in_room(event)
+                self._can_send_event(event)

                 if event.type == RoomPowerLevelsEvent.TYPE:
-                    yield self._check_power_levels(event)
+                    self._check_power_levels(event)

                 if event.type == RoomRedactionEvent.TYPE:
-                    yield self._check_redaction(event)
+                    self._check_redaction(event)

-                defer.returnValue(True)
+                logger.debug("Allowing! %s", event)
+                return True
             else:
                 raise AuthError(500, "Unknown event: %s" % event)
         except AuthError as e:
-            logger.info("Event auth check failed on event %s with msg: %s",
-                event, e.msg)
+            logger.info(
+                "Event auth check failed on event %s with msg: %s",
+                event, e.msg
+            )
+            logger.info("Denying! %s", event)
             if raises:
                 raise e
-            defer.returnValue(False)
+
+            return False

     @defer.inlineCallbacks
     def check_joined_room(self, room_id, user_id):
@@ -98,45 +107,80 @@ class Auth(object):
             pass
         defer.returnValue(None)

+    @defer.inlineCallbacks
+    def check_host_in_room(self, room_id, host):
+        joined_hosts = yield self.store.get_joined_hosts_for_room(room_id)
+
+        defer.returnValue(host in joined_hosts)
+
+    def check_event_sender_in_room(self, event):
+        key = (RoomMemberEvent.TYPE, event.user_id, )
+        member_event = event.state_events.get(key)
+
+        return self._check_joined_room(
+            member_event,
+            event.user_id,
+            event.room_id
+        )
+
     def _check_joined_room(self, member, user_id, room_id):
         if not member or member.membership != Membership.JOIN:
             raise AuthError(403, "User %s not in room %s (%s)" % (
                 user_id, room_id, repr(member)
             ))

-    @defer.inlineCallbacks
     @log_function
     def is_membership_change_allowed(self, event):
         target_user_id = event.state_key

-        # does this room even exist
-        room = yield self.store.get_room(event.room_id)
-        if not room:
-            raise AuthError(403, "Room does not exist")
-
         # get info about the caller
-        try:
-            caller = yield self.store.get_room_member(
-                user_id=event.user_id,
-                room_id=event.room_id)
-        except:
-            caller = None
-        caller_in_room = caller and caller.membership == "join"
+        key = (RoomMemberEvent.TYPE, event.user_id, )
+        caller = event.old_state_events.get(key)
+
+        caller_in_room = caller and caller.membership == Membership.JOIN
+        caller_invited = caller and caller.membership == Membership.INVITE

         # get info about the target
-        try:
-            target = yield self.store.get_room_member(
-                user_id=target_user_id,
-                room_id=event.room_id)
-        except:
-            target = None
-        target_in_room = target and target.membership == "join"
+        key = (RoomMemberEvent.TYPE, target_user_id, )
+        target = event.old_state_events.get(key)
+
+        target_in_room = target and target.membership == Membership.JOIN

         membership = event.content["membership"]

-        join_rule = yield self.store.get_room_join_rule(event.room_id)
-        if not join_rule:
+        key = (RoomJoinRulesEvent.TYPE, "", )
+        join_rule_event = event.old_state_events.get(key)
+        if join_rule_event:
+            join_rule = join_rule_event.content.get(
+                "join_rule", JoinRules.INVITE
+            )
+        else:
             join_rule = JoinRules.INVITE

+        user_level = self._get_power_level_from_event_state(
+            event,
+            event.user_id,
+        )
+
+        ban_level, kick_level, redact_level = (
+            self._get_ops_level_from_event_state(
+                event
+            )
+        )
+
+        logger.debug(
+            "is_membership_change_allowed: %s",
+            {
+                "caller_in_room": caller_in_room,
+                "caller_invited": caller_invited,
+                "target_in_room": target_in_room,
+                "membership": membership,
+                "join_rule": join_rule,
+                "target_user_id": target_user_id,
+                "event.user_id": event.user_id,
+            }
+        )
+
         if Membership.INVITE == membership:
             # TODO (erikj): We should probably handle this more intelligently
             # PRIVATE join rules.
@@ -153,13 +197,10 @@ class Auth(object):
             # joined: It's a NOOP
             if event.user_id != target_user_id:
                 raise AuthError(403, "Cannot force another user to join.")
-        elif join_rule == JoinRules.PUBLIC or room.is_public:
+        elif join_rule == JoinRules.PUBLIC:
             pass
         elif join_rule == JoinRules.INVITE:
-            if (
-                not caller or caller.membership not in
-                [Membership.INVITE, Membership.JOIN]
-            ):
+            if not caller_in_room and not caller_invited:
                 raise AuthError(403, "You are not invited to this room.")
         else:
             # TODO (erikj): may_join list
@@ -171,29 +212,16 @@ class Auth(object):
             if not caller_in_room:  # trying to leave a room you aren't joined
                 raise AuthError(403, "You are not in room %s." % event.room_id)
             elif target_user_id != event.user_id:
-                user_level = yield self.store.get_power_level(
-                    event.room_id,
-                    event.user_id,
-                )
-                _, kick_level, _ = yield self.store.get_ops_levels(event.room_id)
-
                 if kick_level:
                     kick_level = int(kick_level)
                 else:
-                    kick_level = 50
+                    kick_level = 50  # FIXME (erikj): What should we do here?

                 if user_level < kick_level:
                     raise AuthError(
                         403, "You cannot kick user %s." % target_user_id
                     )
         elif Membership.BAN == membership:
-            user_level = yield self.store.get_power_level(
-                event.room_id,
-                event.user_id,
-            )
-
-            ban_level, _, _ = yield self.store.get_ops_levels(event.room_id)
-
             if ban_level:
                 ban_level = int(ban_level)
             else:
@@ -204,7 +232,30 @@ class Auth(object):
         else:
             raise AuthError(500, "Unknown membership %s" % membership)

-        defer.returnValue(True)
+        return True

+    def _get_power_level_from_event_state(self, event, user_id):
+        key = (RoomPowerLevelsEvent.TYPE, "", )
+        power_level_event = event.old_state_events.get(key)
+        level = None
+        if power_level_event:
+            level = power_level_event.content.get("users", {}).get(user_id)
+            if not level:
+                level = power_level_event.content.get("users_default", 0)
+
+        return level
+
+    def _get_ops_level_from_event_state(self, event):
+        key = (RoomPowerLevelsEvent.TYPE, "", )
+        power_level_event = event.old_state_events.get(key)
+
+        if power_level_event:
+            return (
+                power_level_event.content.get("ban", 50),
+                power_level_event.content.get("kick", 50),
+                power_level_event.content.get("redact", 50),
+            )
+        return None, None, None,
+
     @defer.inlineCallbacks
     def get_user_by_req(self, request):
@@ -229,7 +280,7 @@ class Auth(object):
                 default=[""]
             )[0]
             if user and access_token and ip_addr:
-                self.store.insert_client_ip(
+                yield self.store.insert_client_ip(
                     user=user,
                     access_token=access_token,
                     device_id=user_info["device_id"],
@@ -273,17 +324,81 @@ class Auth(object):
         return self.store.is_server_admin(user)

     @defer.inlineCallbacks
+    def add_auth_events(self, event):
+        if event.type == RoomCreateEvent.TYPE:
+            event.auth_events = []
+            return
+
+        auth_events = []
+
+        key = (RoomPowerLevelsEvent.TYPE, "", )
+        power_level_event = event.old_state_events.get(key)
+
+        if power_level_event:
+            auth_events.append(power_level_event.event_id)
+
+        key = (RoomJoinRulesEvent.TYPE, "", )
+        join_rule_event = event.old_state_events.get(key)
+
+        key = (RoomMemberEvent.TYPE, event.user_id, )
+        member_event = event.old_state_events.get(key)
+
+        if join_rule_event:
+            join_rule = join_rule_event.content.get("join_rule")
+            is_public = join_rule == JoinRules.PUBLIC if join_rule else False
+        else:
+            is_public = False
+
+        if event.type == RoomMemberEvent.TYPE:
+            e_type = event.content["membership"]
+            if e_type in [Membership.JOIN, Membership.INVITE]:
+                if join_rule_event:
+                    auth_events.append(join_rule_event.event_id)
+
+            if member_event and not is_public:
+                auth_events.append(member_event.event_id)
+        elif member_event:
+            if member_event.content["membership"] == Membership.JOIN:
+                auth_events.append(member_event.event_id)
+
+        hashes = yield self.store.get_event_reference_hashes(
+            auth_events
+        )
+        hashes = [
+            {
+                k: encode_base64(v) for k, v in h.items()
+                if k == "sha256"
+            }
+            for h in hashes
+        ]
+        event.auth_events = zip(auth_events, hashes)
+
+    @log_function
     def _can_send_event(self, event):
-        send_level = yield self.store.get_send_event_level(event.room_id)
+        key = (RoomPowerLevelsEvent.TYPE, "", )
+        send_level_event = event.old_state_events.get(key)
+        send_level = None
+        if send_level_event:
+            send_level = send_level_event.content.get("events", {}).get(
+                event.type
+            )
+            if not send_level:
+                if hasattr(event, "state_key"):
+                    send_level = send_level_event.content.get(
+                        "state_default", 50
+                    )
+                else:
+                    send_level = send_level_event.content.get(
+                        "events_default", 0
+                    )

         if send_level:
             send_level = int(send_level)
         else:
             send_level = 0

-        user_level = yield self.store.get_power_level(
-            event.room_id,
+        user_level = self._get_power_level_from_event_state(
+            event,
             event.user_id,
         )

@@ -294,84 +409,22 @@ class Auth(object):

         if user_level < send_level:
             raise AuthError(
-                403, "You don't have permission to post to the room"
+                403,
+                "You don't have permission to post that to the room. " +
+                "user_level (%d) < send_level (%d)" % (user_level, send_level)
             )

-        defer.returnValue(True)
+        return True

-    @defer.inlineCallbacks
-    def _can_add_state(self, event):
-        add_level = yield self.store.get_add_state_level(event.room_id)
-
-        if not add_level:
-            defer.returnValue(True)
-
-        add_level = int(add_level)
-
-        user_level = yield self.store.get_power_level(
-            event.room_id,
-            event.user_id,
-        )
-
-        user_level = int(user_level)
-
-        if user_level < add_level:
-            raise AuthError(
-                403, "You don't have permission to add state to the room"
-            )
-
-        defer.returnValue(True)
-
-    @defer.inlineCallbacks
-    def _can_replace_state(self, event):
-        current_state = yield self.store.get_current_state(
-            event.room_id,
-            event.type,
-            event.state_key,
-        )
-
-        if current_state:
-            current_state = current_state[0]
-
-        user_level = yield self.store.get_power_level(
-            event.room_id,
-            event.user_id,
-        )
-
-        if user_level:
-            user_level = int(user_level)
-        else:
-            user_level = 0
-
-        logger.debug(
-            "Checking power level for %s, %s", event.user_id, user_level
-        )
-        if current_state and hasattr(current_state, "required_power_level"):
-            req = current_state.required_power_level
-
-            logger.debug("Checked power level for %s, %s", event.user_id, req)
-            if user_level < req:
-                raise AuthError(
-                    403,
-                    "You don't have permission to change that state"
-                )
-
-    @defer.inlineCallbacks
     def _check_redaction(self, event):
-        user_level = yield self.store.get_power_level(
-            event.room_id,
+        user_level = self._get_power_level_from_event_state(
+            event,
             event.user_id,
         )

         if user_level:
             user_level = int(user_level)
         else:
             user_level = 0

-        _, _, redact_level = yield self.store.get_ops_levels(event.room_id)
-
-        if not redact_level:
-            redact_level = 50
+        _, _, redact_level = self._get_ops_level_from_event_state(
+            event
+        )

         if user_level < redact_level:
             raise AuthError(
@@ -379,16 +432,10 @@ class Auth(object):
                 "You don't have permission to redact events"
             )

-    @defer.inlineCallbacks
     def _check_power_levels(self, event):
-        for k, v in event.content.items():
-            if k == "default":
-                continue
-
-            # FIXME (erikj): We don't want hsob_Ts in content.
-            if k == "hsob_ts":
-                continue
-
+        user_list = event.content.get("users", {})
+        # Validate users
+        for k, v in user_list.items():
             try:
                 self.hs.parse_userid(k)
             except:
@@ -399,80 +446,68 @@ class Auth(object):
             except:
                 raise SynapseError(400, "Not a valid power level: %s" % (v,))

-        current_state = yield self.store.get_current_state(
-            event.room_id,
-            event.type,
-            event.state_key,
-        )
+        key = (event.type, event.state_key, )
+        current_state = event.old_state_events.get(key)

         if not current_state:
             return
-        else:
-            current_state = current_state[0]

-        user_level = yield self.store.get_power_level(
-            event.room_id,
+        user_level = self._get_power_level_from_event_state(
+            event,
             event.user_id,
         )

         if user_level:
             user_level = int(user_level)
         else:
             user_level = 0

+        # Check other levels:
+        levels_to_check = [
+            ("users_default", []),
+            ("events_default", []),
+            ("ban", []),
+            ("redact", []),
+            ("kick", []),
+        ]
+
-        old_list = current_state.content
+        old_list = current_state.content.get("users")
+        for user in set(old_list.keys() + user_list.keys()):
+            levels_to_check.append(
+                (user, ["users"])
+            )

-        # FIXME (erikj)
-        old_people = {k: v for k, v in old_list.items() if k.startswith("@")}
-        new_people = {
-            k: v for k, v in event.content.items()
-            if k.startswith("@")
-        }
+        old_list = current_state.content.get("events")
+        new_list = event.content.get("events")
+        for ev_id in set(old_list.keys() + new_list.keys()):
+            levels_to_check.append(
+                (ev_id, ["events"])
+            )

-        removed = set(old_people.keys()) - set(new_people.keys())
-        added = set(new_people.keys()) - set(old_people.keys())
-        same = set(old_people.keys()) & set(new_people.keys())
+        old_state = current_state.content
+        new_state = event.content

-        for r in removed:
-            if int(old_list[r]) > user_level:
-                raise AuthError(
-                    403,
-                    "You don't have permission to remove user: %s" % (r, )
-                )
+        for level_to_check, dir in levels_to_check:
+            old_loc = old_state
+            for d in dir:
+                old_loc = old_loc.get(d, {})

-        for n in added:
-            if int(event.content[n]) > user_level:
+            new_loc = new_state
+            for d in dir:
+                new_loc = new_loc.get(d, {})
+
+            if level_to_check in old_loc:
+                old_level = int(old_loc[level_to_check])
+            else:
+                old_level = None
+
+            if level_to_check in new_loc:
+                new_level = int(new_loc[level_to_check])
+            else:
+                new_level = None
+
+            if new_level is not None and old_level is not None:
+                if new_level == old_level:
+                    continue
+
+            if old_level > user_level or new_level > user_level:
                 raise AuthError(
                     403,
                     "You don't have permission to add ops level greater "
                     "than your own"
                 )
-
-        for s in same:
-            if int(event.content[s]) != int(old_list[s]):
-                if int(event.content[s]) > user_level:
-                    raise AuthError(
-                        403,
-                        "You don't have permission to add ops level greater "
-                        "than your own"
-                    )
-
-        if "default" in old_list:
-            old_default = int(old_list["default"])
-
-            if old_default > user_level:
-                raise AuthError(
-                    403,
-                    "You don't have permission to add ops level greater than "
-                    "your own"
-                )
-
-        if "default" in event.content:
-            new_default = int(event.content["default"])
-
-            if new_default > user_level:
-                raise AuthError(
-                    403,
-                    "You don't have permission to add ops level greater "
-                    "than your own"
-                )
synapse/api/errors.py:
@@ -158,3 +158,37 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
     for key, value in kwargs.iteritems():
         err[key] = value
     return err
+
+
+class FederationError(RuntimeError):
+    """ This class is used to inform remote home servers about erroneous
+    PDUs they sent us.
+
+    FATAL: The remote server could not interpret the source event.
+        (e.g., it was missing a required field)
+    ERROR: The remote server interpreted the event, but it failed some other
+        check (e.g. auth)
+    WARN: The remote server accepted the event, but believes some part of it
+        is wrong (e.g., it referred to an invalid event)
+    """
+
+    def __init__(self, level, code, reason, affected, source=None):
+        if level not in ["FATAL", "ERROR", "WARN"]:
+            raise ValueError("Level is not valid: %s" % (level,))
+        self.level = level
+        self.code = code
+        self.reason = reason
+        self.affected = affected
+        self.source = source
+
+        msg = "%s %s: %s" % (level, code, reason,)
+        super(FederationError, self).__init__(msg)
+
+    def get_dict(self):
+        return {
+            "level": self.level,
+            "code": self.code,
+            "reason": self.reason,
+            "affected": self.affected,
+            "source": self.source if self.source else self.affected,
+        }
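A quick sketch of how the class added above is meant to be used; the event ID
below is a made-up placeholder::

    err = FederationError(
        level="ERROR",            # must be one of FATAL / ERROR / WARN
        code=403,
        reason="auth check failed",
        affected="$made_up_event_id:remote.example",
    )
    print(err.get_dict())
    # "source" falls back to "affected" when no source is supplied.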
synapse/api/events/__init__.py:
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from synapse.api.errors import SynapseError, Codes
 from synapse.util.jsonobject import JsonEncodedObject


@@ -56,22 +55,26 @@ class SynapseEvent(JsonEncodedObject):
         "user_id",  # sender/initiator
         "content",  # HTTP body, JSON
         "state_key",
-        "required_power_level",
         "age_ts",
         "prev_content",
-        "prev_state",
+        "replaces_state",
         "redacted_because",
+        "origin_server_ts",
     ]

     internal_keys = [
         "is_state",
-        "prev_events",
         "depth",
         "destinations",
         "origin",
         "outlier",
-        "power_level",
         "redacted",
+        "prev_events",
+        "hashes",
+        "signatures",
+        "prev_state",
+        "auth_events",
+        "state_hash",
     ]

     required_keys = [
@@ -82,8 +85,8 @@ class SynapseEvent(JsonEncodedObject):

     def __init__(self, raises=True, **kwargs):
         super(SynapseEvent, self).__init__(**kwargs)
-        if "content" in kwargs:
-            self.check_json(self.content, raises=raises)
+        # if "content" in kwargs:
+        #     self.check_json(self.content, raises=raises)

     def get_content_template(self):
         """ Retrieve the JSON template for this event as a dict.
@@ -114,66 +117,6 @@ class SynapseEvent(JsonEncodedObject):
         """
         raise NotImplementedError("get_content_template not implemented.")

-    def check_json(self, content, raises=True):
-        """Checks the given JSON content abides by the rules of the template.
-
-        Args:
-            content : A JSON object to check.
-            raises: True to raise a SynapseError if the check fails.
-        Returns:
-            True if the content passes the template. Returns False if the check
-            fails and raises=False.
-        Raises:
-            SynapseError if the check fails and raises=True.
-        """
-        # recursively call to inspect each layer
-        err_msg = self._check_json(content, self.get_content_template())
-        if err_msg:
-            if raises:
-                raise SynapseError(400, err_msg, Codes.BAD_JSON)
-            else:
-                return False
-        else:
-            return True
-
-    def _check_json(self, content, template):
-        """Check content and template matches.
-
-        If the template is a dict, each key in the dict will be validated with
-        the content, else it will just compare the types of content and
-        template. This basic type check is required because this function will
-        be recursively called and could be called with just strs or ints.
-
-        Args:
-            content: The content to validate.
-            template: The validation template.
-        Returns:
-            str: An error message if the validation fails, else None.
-        """
-        if type(content) != type(template):
-            return "Mismatched types: %s" % template
-
-        if type(template) == dict:
-            for key in template:
-                if key not in content:
-                    return "Missing %s key" % key
-
-                if type(content[key]) != type(template[key]):
-                    return "Key %s is of the wrong type (got %s, want %s)" % (
-                        key, type(content[key]), type(template[key]))
-
-                if type(content[key]) == dict:
-                    # we must go deeper
-                    msg = self._check_json(content[key], template[key])
-                    if msg:
-                        return msg
-                elif type(content[key]) == list:
-                    # make sure each item type in content matches the template
-                    for entry in content[key]:
-                        msg = self._check_json(entry, template[key][0])
-                        if msg:
-                            return msg
-

 class SynapseStateEvent(SynapseEvent):
synapse/api/events/factory.py:
@@ -16,11 +16,13 @@
 from synapse.api.events.room import (
     RoomTopicEvent, MessageEvent, RoomMemberEvent, FeedbackEvent,
     InviteJoinEvent, RoomConfigEvent, RoomNameEvent, GenericEvent,
-    RoomPowerLevelsEvent, RoomJoinRulesEvent, RoomOpsPowerLevelsEvent,
-    RoomCreateEvent, RoomAddStateLevelEvent, RoomSendEventLevelEvent,
+    RoomPowerLevelsEvent, RoomJoinRulesEvent,
+    RoomCreateEvent,
     RoomRedactionEvent,
 )

+from synapse.types import EventID
+
 from synapse.util.stringutils import random_string


@@ -37,9 +39,6 @@ class EventFactory(object):
         RoomPowerLevelsEvent,
         RoomJoinRulesEvent,
         RoomCreateEvent,
-        RoomAddStateLevelEvent,
-        RoomSendEventLevelEvent,
-        RoomOpsPowerLevelsEvent,
         RoomRedactionEvent,
     ]

@@ -51,12 +50,26 @@ class EventFactory(object):
         self.clock = hs.get_clock()
         self.hs = hs

+        self.event_id_count = 0
+
+    def create_event_id(self):
+        i = str(self.event_id_count)
+        self.event_id_count += 1
+
+        local_part = str(int(self.clock.time())) + i + random_string(5)
+
+        e_id = EventID.create_local(local_part, self.hs)
+
+        return e_id.to_string()
+
     def create_event(self, etype=None, **kwargs):
         kwargs["type"] = etype
         if "event_id" not in kwargs:
-            kwargs["event_id"] = "%s@%s" % (
-                random_string(10), self.hs.hostname
-            )
+            kwargs["event_id"] = self.create_event_id()
+            kwargs["origin"] = self.hs.hostname
+        else:
+            ev_id = self.hs.parse_eventid(kwargs["event_id"])
+            kwargs["origin"] = ev_id.domain

         if "origin_server_ts" not in kwargs:
             kwargs["origin_server_ts"] = int(self.clock.time_msec())

synapse/api/events/room.py:
@@ -154,27 +154,6 @@ class RoomPowerLevelsEvent(SynapseStateEvent):
         return {}


-class RoomAddStateLevelEvent(SynapseStateEvent):
-    TYPE = "m.room.add_state_level"
-
-    def get_content_template(self):
-        return {}
-
-
-class RoomSendEventLevelEvent(SynapseStateEvent):
-    TYPE = "m.room.send_event_level"
-
-    def get_content_template(self):
-        return {}
-
-
-class RoomOpsPowerLevelsEvent(SynapseStateEvent):
-    TYPE = "m.room.ops_levels"
-
-    def get_content_template(self):
-        return {}
-
-
 class RoomAliasesEvent(SynapseStateEvent):
     TYPE = "m.room.aliases"
synapse/api/events/utils.py:
@@ -15,21 +15,34 @@

 from .room import (
     RoomMemberEvent, RoomJoinRulesEvent, RoomPowerLevelsEvent,
-    RoomAddStateLevelEvent, RoomSendEventLevelEvent, RoomOpsPowerLevelsEvent,
     RoomAliasesEvent, RoomCreateEvent,
 )


 def prune_event(event):
-    """ Prunes the given event of all keys we don't know about or think could
-    potentially be dodgy.
+    """ Returns a pruned version of the given event, which removes all keys we
+    don't know about or think could potentially be dodgy.

     This is used when we "redact" an event. We want to remove all fields that
     the user has specified, but we do want to keep necessary information like
     type, state_key etc.
     """
+    event_type = event.type

-    # Remove all extraneous fields.
-    event.unrecognized_keys = {}
+    allowed_keys = [
+        "event_id",
+        "user_id",
+        "room_id",
+        "hashes",
+        "signatures",
+        "content",
+        "type",
+        "state_key",
+        "depth",
+        "prev_events",
+        "prev_state",
+        "auth_events",
+    ]

     new_content = {}

@@ -38,27 +51,33 @@ def prune_event(event):
         if field in event.content:
             new_content[field] = event.content[field]

-    if event.type == RoomMemberEvent.TYPE:
+    if event_type == RoomMemberEvent.TYPE:
         add_fields("membership")
-    elif event.type == RoomCreateEvent.TYPE:
+    elif event_type == RoomCreateEvent.TYPE:
         add_fields("creator")
-    elif event.type == RoomJoinRulesEvent.TYPE:
+    elif event_type == RoomJoinRulesEvent.TYPE:
         add_fields("join_rule")
-    elif event.type == RoomPowerLevelsEvent.TYPE:
-        # TODO: Actually check these are valid user_ids etc.
-        add_fields("default")
-        for k, v in event.content.items():
-            if k.startswith("@") and isinstance(v, (int, long)):
-                new_content[k] = v
-    elif event.type == RoomAddStateLevelEvent.TYPE:
-        add_fields("level")
-    elif event.type == RoomSendEventLevelEvent.TYPE:
-        add_fields("level")
-    elif event.type == RoomOpsPowerLevelsEvent.TYPE:
-        add_fields("kick_level", "ban_level", "redact_level")
-    elif event.type == RoomAliasesEvent.TYPE:
+    elif event_type == RoomPowerLevelsEvent.TYPE:
+        add_fields(
+            "users",
+            "users_default",
+            "events",
+            "events_default",
+            "state_default",
+            "ban",
+            "kick",
+            "redact",
+        )
+    elif event_type == RoomAliasesEvent.TYPE:
         add_fields("aliases")

-    event.content = new_content
+    allowed_fields = {
+        k: v
+        for k, v in event.get_full_dict().items()
+        if k in allowed_keys
+    }

-    return event
+    allowed_fields["content"] = new_content
+
+    return type(event)(**allowed_fields)
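To picture the redaction rule implemented above, here is a toy sketch on plain
dicts rather than Synapse's event classes; the sample event is invented::

    ALLOWED_KEYS = [
        "event_id", "user_id", "room_id", "hashes", "signatures", "content",
        "type", "state_key", "depth", "prev_events", "prev_state", "auth_events",
    ]

    def prune_dict(event_dict, keep_content_fields):
        # Keep only whitelisted top-level keys, then rebuild "content" from
        # the fields allowed for this event type.
        pruned = {k: v for k, v in event_dict.items() if k in ALLOWED_KEYS}
        pruned["content"] = {
            k: v for k, v in event_dict.get("content", {}).items()
            if k in keep_content_fields
        }
        return pruned

    event = {
        "event_id": "abc@example.com",        # invented
        "type": "m.room.member",
        "content": {"membership": "join", "displayname": "Alice"},
        "internal_field": "dropped by the prune",
    }
    # For m.room.member only "membership" survives, as in prune_event above.
    print(prune_dict(event, ["membership"]))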
synapse/api/events/validator.py (new file, 87 lines):
@@ -0,0 +1,87 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.api.errors import SynapseError, Codes
+
+
+class EventValidator(object):
+    def __init__(self, hs):
+        pass
+
+    def validate(self, event):
+        """Checks the given JSON content abides by the rules of the template.
+
+        Args:
+            content : A JSON object to check.
+            raises: True to raise a SynapseError if the check fails.
+        Returns:
+            True if the content passes the template. Returns False if the check
+            fails and raises=False.
+        Raises:
+            SynapseError if the check fails and raises=True.
+        """
+        # recursively call to inspect each layer
+        err_msg = self._check_json_template(
+            event.content,
+            event.get_content_template()
+        )
+        if err_msg:
+            raise SynapseError(400, err_msg, Codes.BAD_JSON)
+        else:
+            return True
+
+    def _check_json_template(self, content, template):
+        """Check content and template matches.
+
+        If the template is a dict, each key in the dict will be validated with
+        the content, else it will just compare the types of content and
+        template. This basic type check is required because this function will
+        be recursively called and could be called with just strs or ints.
+
+        Args:
+            content: The content to validate.
+            template: The validation template.
+        Returns:
+            str: An error message if the validation fails, else None.
+        """
+        if type(content) != type(template):
+            return "Mismatched types: %s" % template
+
+        if type(template) == dict:
+            for key in template:
+                if key not in content:
+                    return "Missing %s key" % key
+
+                if type(content[key]) != type(template[key]):
+                    return "Key %s is of the wrong type (got %s, want %s)" % (
+                        key, type(content[key]), type(template[key]))
+
+                if type(content[key]) == dict:
+                    # we must go deeper
+                    msg = self._check_json_template(
+                        content[key],
+                        template[key]
+                    )
+                    if msg:
+                        return msg
+                elif type(content[key]) == list:
+                    # make sure each item type in content matches the template
+                    for entry in content[key]:
+                        msg = self._check_json_template(
+                            entry,
+                            template[key][0]
+                        )
+                        if msg:
+                            return msg
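A hedged sketch of how the template check above behaves on plain data; the
template and content are invented examples, not real Synapse event templates::

    template = {"body": u"", "count": 0, "tags": [u""]}

    good = {"body": u"hello", "count": 3, "tags": [u"a", u"b"]}
    bad = {"body": u"hello", "count": "3", "tags": [u"a"]}

    validator = EventValidator(hs=None)
    # None: every key is present with a matching type.
    print(validator._check_json_template(good, template))
    # "Key count is of the wrong type ...": str given where int expected.
    print(validator._check_json_template(bad, template))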
synapse/app/homeserver.py:
@@ -236,7 +236,10 @@ def setup():
     f.namespace['hs'] = hs
     reactor.listenTCP(config.manhole, f, interface='127.0.0.1')

-    hs.start_listening(config.bind_port, config.unsecure_port)
+    bind_port = config.bind_port
+    if config.no_tls:
+        bind_port = None
+    hs.start_listening(bind_port, config.unsecure_port)

     if config.daemonize:
         print config.pid_file

synapse/config/server.py:
@@ -30,6 +30,7 @@ class ServerConfig(Config):
         self.pid_file = self.abspath(args.pid_file)
         self.webclient = True
         self.manhole = args.manhole
+        self.no_tls = args.no_tls

         if not args.content_addr:
             host = args.server_name
@@ -67,6 +68,8 @@ class ServerConfig(Config):
         server_group.add_argument("--content-addr", default=None,
                                   help="The host and scheme to use for the "
                                        "content repository")
+        server_group.add_argument("--no-tls", action='store_true',
+                                  help="Don't bind to the https port.")

     def read_signing_key(self, signing_key_path):
         signing_keys = self.read_file(signing_key_path, "signing_key")
synapse/crypto/event_signing.py (new file, 98 lines):
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.api.events.utils import prune_event
+from syutil.jsonutil import encode_canonical_json
+from syutil.base64util import encode_base64, decode_base64
+from syutil.crypto.jsonsign import sign_json
+
+import hashlib
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
+    """Check whether the hash for this PDU matches the contents"""
+    computed_hash = _compute_content_hash(event, hash_algorithm)
+    if computed_hash.name not in event.hashes:
+        raise Exception("Algorithm %s not in hashes %s" % (
+            computed_hash.name, list(event.hashes)
+        ))
+    message_hash_base64 = event.hashes[computed_hash.name]
+    try:
+        message_hash_bytes = decode_base64(message_hash_base64)
+    except:
+        raise Exception("Invalid base64: %s" % (message_hash_base64,))
+    return message_hash_bytes == computed_hash.digest()
+
+
+def _compute_content_hash(event, hash_algorithm):
+    event_json = event.get_full_dict()
+    # TODO: We need to sign the JSON that is going out via federation.
+    event_json.pop("age_ts", None)
+    event_json.pop("unsigned", None)
+    event_json.pop("signatures", None)
+    event_json.pop("hashes", None)
+    event_json_bytes = encode_canonical_json(event_json)
+    return hash_algorithm(event_json_bytes)
+
+
+def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
+    tmp_event = prune_event(event)
+    event_json = tmp_event.get_dict()
+    event_json.pop("signatures", None)
+    event_json.pop("age_ts", None)
+    event_json.pop("unsigned", None)
+    event_json_bytes = encode_canonical_json(event_json)
+    hashed = hash_algorithm(event_json_bytes)
+    return (hashed.name, hashed.digest())
+
+
+def compute_event_signature(event, signature_name, signing_key):
+    tmp_event = prune_event(event)
+    redact_json = tmp_event.get_full_dict()
+    redact_json.pop("signatures", None)
+    redact_json.pop("age_ts", None)
+    redact_json.pop("unsigned", None)
+    logger.debug("Signing event: %s", redact_json)
+    redact_json = sign_json(redact_json, signature_name, signing_key)
+    return redact_json["signatures"]
+
+
+def add_hashes_and_signatures(event, signature_name, signing_key,
+                              hash_algorithm=hashlib.sha256):
+    if hasattr(event, "old_state_events"):
+        state_json_bytes = encode_canonical_json(
+            [e.event_id for e in event.old_state_events.values()]
+        )
+        hashed = hash_algorithm(state_json_bytes)
+        event.state_hash = {
+            hashed.name: encode_base64(hashed.digest())
+        }
+
+    hashed = _compute_content_hash(event, hash_algorithm=hash_algorithm)
+
+    if not hasattr(event, "hashes"):
+        event.hashes = {}
+    event.hashes[hashed.name] = encode_base64(hashed.digest())
+
+    event.signatures = compute_event_signature(
+        event,
+        signature_name=signature_name,
+        signing_key=signing_key,
+    )
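The content-hash computation above is easy to reproduce on plain dicts. A rough
sketch with hashlib, reusing the approximate canonical_json() from the earlier
signing sketch (the event below is invented)::

    import base64
    import hashlib
    import json

    def canonical_json(obj):
        return json.dumps(obj, sort_keys=True, separators=(",", ":")).encode("utf-8")

    event = {
        "type": "m.room.message",             # invented event
        "content": {"body": "hello"},
    }

    # Hash everything except the keys _compute_content_hash pops off above.
    to_hash = {k: v for k, v in event.items()
               if k not in ("age_ts", "unsigned", "signatures", "hashes")}
    digest = hashlib.sha256(canonical_json(to_hash)).digest()
    event.setdefault("hashes", {})["sha256"] = (
        base64.b64encode(digest).decode("ascii").rstrip("=")
    )
    print(event["hashes"])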
synapse/federation/pdu_codec.py:
@@ -18,50 +18,25 @@ from .units import Pdu
 import copy


-def decode_event_id(event_id, server_name):
-    parts = event_id.split("@")
-    if len(parts) < 2:
-        return (event_id, server_name)
-    else:
-        return (parts[0], "".join(parts[1:]))
-
-
 def encode_event_id(pdu_id, origin):
     return "%s@%s" % (pdu_id, origin)


 class PduCodec(object):

     def __init__(self, hs):
+        self.signing_key = hs.config.signing_key[0]
         self.server_name = hs.hostname
         self.event_factory = hs.get_event_factory()
         self.clock = hs.get_clock()
+        self.hs = hs

     def event_from_pdu(self, pdu):
         kwargs = {}

         kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
         kwargs["room_id"] = pdu.context
-        kwargs["etype"] = pdu.pdu_type
-        kwargs["prev_events"] = [
-            encode_event_id(p[0], p[1]) for p in pdu.prev_pdus
-        ]
-
-        if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
-            kwargs["prev_state"] = encode_event_id(
-                pdu.prev_state_id, pdu.prev_state_origin
-            )
+        kwargs["etype"] = pdu.type

         kwargs.update({
             k: v
             for k, v in pdu.get_full_dict().items()
             if k not in [
                 "pdu_id",
                 "context",
                 "pdu_type",
-                "prev_pdus",
-                "prev_state_id",
-                "prev_state_origin",
+                "type",
             ]
         })

@@ -70,33 +45,10 @@ class PduCodec(object):
     def pdu_from_event(self, event):
         d = event.get_full_dict()

-        d["pdu_id"], d["origin"] = decode_event_id(
-            event.event_id, self.server_name
-        )
-        d["context"] = event.room_id
-        d["pdu_type"] = event.type
-
-        if hasattr(event, "prev_events"):
-            d["prev_pdus"] = [
-                decode_event_id(e, self.server_name)
-                for e in event.prev_events
-            ]
-
-        if hasattr(event, "prev_state"):
-            d["prev_state_id"], d["prev_state_origin"] = (
-                decode_event_id(event.prev_state, self.server_name)
-            )
-
-        if hasattr(event, "state_key"):
-            d["is_state"] = True
-
         kwargs = copy.deepcopy(event.unrecognized_keys)
         kwargs.update({
             k: v for k, v in d.items()
             if k not in ["event_id", "room_id", "type", "prev_events"]
         })

         if "origin_server_ts" not in kwargs:
             kwargs["origin_server_ts"] = int(self.clock.time_msec())

-        return Pdu(**kwargs)
+        pdu = Pdu(**kwargs)
+        return pdu
synapse/federation/persistence.py:
@@ -21,8 +21,6 @@ These actions are mostly only used by the :py:mod:`.replication` module.

 from twisted.internet import defer

-from .units import Pdu
-
 from synapse.util.logutils import log_function

 import json
@@ -32,76 +30,6 @@ import logging
 logger = logging.getLogger(__name__)


-class PduActions(object):
-    """ Defines persistence actions that relate to handling PDUs.
-    """
-
-    def __init__(self, datastore):
-        self.store = datastore
-
-    @log_function
-    def mark_as_processed(self, pdu):
-        """ Persist the fact that we have fully processed the given `Pdu`
-
-        Returns:
-            Deferred
-        """
-        return self.store.mark_pdu_as_processed(pdu.pdu_id, pdu.origin)
-
-    @defer.inlineCallbacks
-    @log_function
-    def after_transaction(self, transaction_id, destination, origin):
-        """ Returns all `Pdu`s that we sent to the given remote home server
-        after a given transaction id.
-
-        Returns:
-            Deferred: Results in a list of `Pdu`s
-        """
-        results = yield self.store.get_pdus_after_transaction(
-            transaction_id,
-            destination
-        )
-
-        defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
-    @defer.inlineCallbacks
-    @log_function
-    def get_all_pdus_from_context(self, context):
-        results = yield self.store.get_all_pdus_from_context(context)
-        defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
-    @defer.inlineCallbacks
-    @log_function
-    def backfill(self, context, pdu_list, limit):
-        """ For a given list of PDU id and origins return the proceeding
-        `limit` `Pdu`s in the given `context`.
-
-        Returns:
-            Deferred: Results in a list of `Pdu`s.
-        """
-        results = yield self.store.get_backfill(
-            context, pdu_list, limit
-        )
-
-        defer.returnValue([Pdu.from_pdu_tuple(p) for p in results])
-
-    @log_function
-    def is_new(self, pdu):
-        """ When we receive a `Pdu` from a remote home server, we want to
-        figure out whether it is `new`, i.e. it is not some historic PDU that
-        we haven't seen simply because we haven't backfilled back that far.
-
-        Returns:
-            Deferred: Results in a `bool`
-        """
-        return self.store.is_pdu_new(
-            pdu_id=pdu.pdu_id,
-            origin=pdu.origin,
-            context=pdu.context,
-            depth=pdu.depth
-        )
-
-
 class TransactionActions(object):
     """ Defines persistence actions that relate to handling Transactions.
     """
@@ -158,7 +86,6 @@ class TransactionActions(object):
             transaction.transaction_id,
             transaction.destination,
             transaction.origin_server_ts,
-            [(p["pdu_id"], p["origin"]) for p in transaction.pdus]
         )

     @log_function
@ -21,7 +21,7 @@ from twisted.internet import defer
|
||||
|
||||
from .units import Transaction, Pdu, Edu
|
||||
|
||||
from .persistence import PduActions, TransactionActions
|
||||
from .persistence import TransactionActions
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
@ -57,7 +57,7 @@ class ReplicationLayer(object):
|
||||
self.transport_layer.register_request_handler(self)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.pdu_actions = PduActions(self.store)
|
||||
# self.pdu_actions = PduActions(self.store)
|
||||
self.transaction_actions = TransactionActions(self.store)
|
||||
|
||||
self._transaction_queue = _TransactionQueue(
|
||||
@ -81,7 +81,7 @@ class ReplicationLayer(object):
|
||||
|
||||
def register_edu_handler(self, edu_type, handler):
|
||||
if edu_type in self.edu_handlers:
|
||||
raise KeyError("Already have an EDU handler for %s" % (edu_type))
|
||||
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
|
||||
|
||||
self.edu_handlers[edu_type] = handler
|
||||
|
||||
@ -102,24 +102,17 @@ class ReplicationLayer(object):
|
||||
object to encode as JSON.
|
||||
"""
|
||||
if query_type in self.query_handlers:
|
||||
raise KeyError("Already have a Query handler for %s" % (query_type))
|
||||
raise KeyError(
|
||||
"Already have a Query handler for %s" % (query_type,)
|
||||
)
|
||||
|
||||
self.query_handlers[query_type] = handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def send_pdu(self, pdu):
|
||||
"""Informs the replication layer about a new PDU generated within the
|
||||
home server that should be transmitted to others.
|
||||
|
||||
This will fill out various attributes on the PDU object, e.g. the
|
||||
`prev_pdus` key.
|
||||
|
||||
*Note:* The home server should always call `send_pdu` even if it knows
|
||||
that it does not need to be replicated to other home servers. This is
|
||||
in case e.g. someone else joins via a remote home server and then
|
||||
backfills.
|
||||
|
||||
TODO: Figure out when we should actually resolve the deferred.
|
||||
|
||||
Args:
|
||||
@ -132,18 +125,15 @@ class ReplicationLayer(object):
|
||||
order = self._order
|
||||
self._order += 1
|
||||
|
||||
logger.debug("[%s] Persisting PDU", pdu.pdu_id)
|
||||
|
||||
# Save *before* trying to send
|
||||
yield self.store.persist_event(pdu=pdu)
|
||||
|
||||
logger.debug("[%s] Persisted PDU", pdu.pdu_id)
|
||||
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.pdu_id)
|
||||
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
|
||||
|
||||
# TODO, add errback, etc.
|
||||
self._transaction_queue.enqueue_pdu(pdu, order)
|
||||
|
||||
logger.debug("[%s] transaction_layer.enqueue_pdu... done", pdu.pdu_id)
|
||||
logger.debug(
|
||||
"[%s] transaction_layer.enqueue_pdu... done",
|
||||
pdu.event_id
|
||||
)
|
||||
|
||||
@log_function
|
||||
def send_edu(self, destination, edu_type, content):
@@ -158,6 +148,11 @@ class ReplicationLayer(object):
self._transaction_queue.enqueue_edu(edu)
return defer.succeed(None)

@log_function
def send_failure(self, failure, destination):
self._transaction_queue.enqueue_failure(failure, destination)
return defer.succeed(None)

@log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=True):
@@ -181,7 +176,7 @@ class ReplicationLayer(object):

@defer.inlineCallbacks
@log_function
def backfill(self, dest, context, limit):
def backfill(self, dest, context, limit, extremities):
"""Requests some more historic PDUs for the given context from the
given destination server.

@@ -189,12 +184,12 @@ class ReplicationLayer(object):
dest (str): The remote home server to ask.
context (str): The context to backfill.
limit (int): The maximum number of PDUs to return.
extremities (list): List of PDU id and origins of the first pdus
we have seen from the context

Returns:
Deferred: Results in the received PDUs.
"""
extremities = yield self.store.get_oldest_pdus_in_context(context)

logger.debug("backfill extrem=%s", extremities)

# If there are no extremities then we've (probably) reached the start.
@@ -210,13 +205,13 @@ class ReplicationLayer(object):

pdus = [Pdu(outlier=False, **p) for p in transaction.pdus]
for pdu in pdus:
yield self._handle_new_pdu(pdu, backfilled=True)
yield self._handle_new_pdu(dest, pdu, backfilled=True)

defer.returnValue(pdus)

@defer.inlineCallbacks
@log_function
def get_pdu(self, destination, pdu_origin, pdu_id, outlier=False):
def get_pdu(self, destination, event_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home
server.

@@ -225,7 +220,7 @@ class ReplicationLayer(object):
Args:
destination (str): Which home server to query
pdu_origin (str): The home server that originally sent the pdu.
pdu_id (str)
event_id (str)
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitrary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
@@ -234,8 +229,9 @@ class ReplicationLayer(object):
Deferred: Results in the requested PDU.
"""

transaction_data = yield self.transport_layer.get_pdu(
destination, pdu_origin, pdu_id)
transaction_data = yield self.transport_layer.get_event(
destination, event_id
)

transaction = Transaction(**transaction_data)

@@ -244,13 +240,13 @@ class ReplicationLayer(object):
pdu = None
if pdu_list:
pdu = pdu_list[0]
yield self._handle_new_pdu(pdu)
yield self._handle_new_pdu(destination, pdu)

defer.returnValue(pdu)

@defer.inlineCallbacks
@log_function
def get_state_for_context(self, destination, context):
def get_state_for_context(self, destination, context, event_id=None):
"""Requests all of the `current` state PDUs for a given context from
a remote home server.

@@ -263,29 +259,25 @@ class ReplicationLayer(object):
"""

transaction_data = yield self.transport_layer.get_context_state(
destination, context)
destination,
context,
event_id=event_id,
)

transaction = Transaction(**transaction_data)

pdus = [Pdu(outlier=True, **p) for p in transaction.pdus]
for pdu in pdus:
yield self._handle_new_pdu(pdu)
yield self._handle_new_pdu(destination, pdu)

defer.returnValue(pdus)

@defer.inlineCallbacks
@log_function
def on_context_pdus_request(self, context):
pdus = yield self.pdu_actions.get_all_pdus_from_context(
context
def on_backfill_request(self, origin, context, versions, limit):
pdus = yield self.handler.on_backfill_request(
origin, context, versions, limit
)
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))

@defer.inlineCallbacks
@log_function
def on_backfill_request(self, context, versions, limit):

pdus = yield self.pdu_actions.backfill(context, versions, limit)

defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))

@@ -295,6 +287,10 @@ class ReplicationLayer(object):
transaction = Transaction(**transaction_data)

for p in transaction.pdus:
if "unsigned" in p:
unsigned = p["unsigned"]
if "age" in unsigned:
p["age"] = unsigned["age"]
if "age" in p:
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
del p["age"]
@@ -315,11 +311,15 @@ class ReplicationLayer(object):

dl = []
for pdu in pdu_list:
dl.append(self._handle_new_pdu(pdu))
dl.append(self._handle_new_pdu(transaction.origin, pdu))

if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu(transaction.origin, edu.edu_type, edu.content)
self.received_edu(
transaction.origin,
edu.edu_type,
edu.content
)

results = yield defer.DeferredList(dl)

@@ -347,20 +347,22 @@ class ReplicationLayer(object):

@defer.inlineCallbacks
@log_function
def on_context_state_request(self, context):
results = yield self.store.get_current_state_for_context(
context
)
def on_context_state_request(self, origin, context, event_id):
if event_id:
pdus = yield self.handler.get_state_for_pdu(
origin,
context,
event_id,
)
else:
raise NotImplementedError("Specify an event")

logger.debug("Context returning %d results", len(results))

pdus = [Pdu.from_pdu_tuple(p) for p in results]
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))

@defer.inlineCallbacks
@log_function
def on_pdu_request(self, pdu_origin, pdu_id):
pdu = yield self._get_persisted_pdu(pdu_id, pdu_origin)
def on_pdu_request(self, origin, event_id):
pdu = yield self._get_persisted_pdu(origin, event_id)

if pdu:
defer.returnValue(
@@ -372,20 +374,7 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
def on_pull_request(self, origin, versions):
transaction_id = max([int(v) for v in versions])

response = yield self.pdu_actions.after_transaction(
transaction_id,
origin,
self.server_name
)

if not response:
response = []

defer.returnValue(
(200, self._transaction_from_pdus(response).get_dict())
)
raise NotImplementedError("Pull transactions not implemented")

@defer.inlineCallbacks
def on_query_request(self, query_type, args):
@@ -393,82 +382,183 @@ class ReplicationLayer(object):
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
defer.returnValue((404, "No handler for Query type '%s'"
% (query_type)
))
defer.returnValue(
(404, "No handler for Query type '%s'" % (query_type, ))
)

@defer.inlineCallbacks
def on_make_join_request(self, context, user_id):
pdu = yield self.handler.on_make_join_request(context, user_id)
defer.returnValue({
"event": pdu.get_dict(),
})

@defer.inlineCallbacks
def on_invite_request(self, origin, content):
pdu = Pdu(**content)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
defer.returnValue(
(
200,
{
"event": ret_pdu.get_dict(),
}
)
)

@defer.inlineCallbacks
def on_send_join_request(self, origin, content):
pdu = Pdu(**content)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)

defer.returnValue((200, {
"state": [p.get_dict() for p in res_pdus["state"]],
"auth_chain": [p.get_dict() for p in res_pdus["auth_chain"]],
}))

@defer.inlineCallbacks
def on_event_auth(self, origin, context, event_id):
auth_pdus = yield self.handler.on_event_auth(event_id)
defer.returnValue(
(
200,
{
"auth_chain": [a.get_dict() for a in auth_pdus],
}
)
)

@defer.inlineCallbacks
def make_join(self, destination, context, user_id):
ret = yield self.transport_layer.make_join(
destination=destination,
context=context,
user_id=user_id,
)

pdu_dict = ret["event"]

logger.debug("Got response to make_join: %s", pdu_dict)

defer.returnValue(Pdu(**pdu_dict))

@defer.inlineCallbacks
def send_join(self, destination, pdu):
_, content = yield self.transport_layer.send_join(
destination,
pdu.room_id,
pdu.event_id,
pdu.get_dict(),
)

logger.debug("Got content: %s", content)
state = [Pdu(outlier=True, **p) for p in content.get("state", [])]
for pdu in state:
yield self._handle_new_pdu(destination, pdu)

auth_chain = [
Pdu(outlier=True, **p) for p in content.get("auth_chain", [])
]
for pdu in auth_chain:
yield self._handle_new_pdu(destination, pdu)

defer.returnValue(state)

@defer.inlineCallbacks
def send_invite(self, destination, context, event_id, pdu):
code, content = yield self.transport_layer.send_invite(
destination=destination,
context=context,
event_id=event_id,
content=pdu.get_dict(),
)

pdu_dict = content["event"]

logger.debug("Got response to send_invite: %s", pdu_dict)

defer.returnValue(Pdu(**pdu_dict))

@log_function
def _get_persisted_pdu(self, pdu_id, pdu_origin):
def _get_persisted_pdu(self, origin, event_id):
""" Get a PDU from the database with given origin and id.

Returns:
Deferred: Results in a `Pdu`.
"""
pdu_tuple = yield self.store.get_pdu(pdu_id, pdu_origin)

defer.returnValue(Pdu.from_pdu_tuple(pdu_tuple))
return self.handler.get_persisted_pdu(origin, event_id)

def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
pdus = [p.get_dict() for p in pdu_list]
time_now = self._clock.time_msec()
for p in pdus:
if "age_ts" in pdus:
p["age"] = int(self.clock.time_msec()) - p["age_ts"]

if "age_ts" in p:
age = time_now - p["age_ts"]
p.setdefault("unsigned", {})["age"] = int(age)
del p["age_ts"]
return Transaction(
origin=self.server_name,
pdus=pdus,
origin_server_ts=int(self._clock.time_msec()),
origin_server_ts=int(time_now),
destination=None,
)

@defer.inlineCallbacks
@log_function
def _handle_new_pdu(self, pdu, backfilled=False):
def _handle_new_pdu(self, origin, pdu, backfilled=False):
# We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(pdu.pdu_id, pdu.origin)
existing = yield self._get_persisted_pdu(origin, pdu.event_id)

if existing and (not existing.outlier or pdu.outlier):
logger.debug("Already seen pdu %s %s", pdu.pdu_id, pdu.origin)
logger.debug("Already seen pdu %s", pdu.event_id)
defer.returnValue({})
return

state = None

# Get missing pdus if necessary.
is_new = yield self.pdu_actions.is_new(pdu)
if is_new and not pdu.outlier:
if not pdu.outlier:
# We only backfill backwards to the min depth.
min_depth = yield self.store.get_min_depth_for_context(pdu.context)
min_depth = yield self.handler.get_min_depth_for_context(
pdu.room_id
)

if min_depth and pdu.depth > min_depth:
for pdu_id, origin in pdu.prev_pdus:
exists = yield self._get_persisted_pdu(pdu_id, origin)
for event_id, hashes in pdu.prev_events:
exists = yield self._get_persisted_pdu(origin, event_id)

if not exists:
logger.debug("Requesting pdu %s %s", pdu_id, origin)
logger.debug("Requesting pdu %s", event_id)

try:
yield self.get_pdu(
pdu.origin,
pdu_id=pdu_id,
pdu_origin=origin
event_id=event_id,
)
logger.debug("Processed pdu %s %s", pdu_id, origin)
logger.debug("Processed pdu %s", event_id)
except:
# TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU")

# Persist the Pdu, but don't mark it as processed yet.
yield self.store.persist_event(pdu=pdu)
else:
# We need to get the state at this event, since we have reached
# a backward extremity edge.
state = yield self.get_state_for_context(
origin, pdu.room_id, pdu.event_id,
)

if not backfilled:
ret = yield self.handler.on_receive_pdu(pdu, backfilled=backfilled)
ret = yield self.handler.on_receive_pdu(
pdu,
backfilled=backfilled,
state=state,
)
else:
ret = None

yield self.pdu_actions.mark_as_processed(pdu)
# yield self.pdu_actions.mark_as_processed(pdu)

defer.returnValue(ret)

@@ -476,14 +566,6 @@ class ReplicationLayer(object):
return "<ReplicationLayer(%s)>" % self.server_name


class ReplicationHandler(object):
"""This defines the methods that the :py:class:`.ReplicationLayer` will
use to communicate with the rest of the home server.
"""
def on_receive_pdu(self, pdu):
raise NotImplementedError("on_receive_pdu")


class _TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
a time for a given destination.
@@ -509,6 +591,9 @@ class _TransactionQueue(object):
# destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = {}

# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}

# HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec())

@@ -561,6 +646,18 @@ class _TransactionQueue(object):

return deferred

@defer.inlineCallbacks
def enqueue_failure(self, failure, destination):
deferred = defer.Deferred()

self.pending_failures_by_dest.setdefault(
destination, []
).append(
(failure, deferred)
)

yield deferred
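The failure queue added here mirrors the existing PDU and EDU queues: failures are buffered per destination and drained into the next transaction as `pdu_failures`. A stripped-down sketch of the buffering pattern (plain dicts stand in for the real failure and `Transaction` objects)::

    pending_failures_by_dest = {}

    def enqueue_failure(failure, destination):
        # Buffered until a transaction for this destination is attempted.
        pending_failures_by_dest.setdefault(destination, []).append(failure)

    def attempt_new_transaction(destination):
        failures = pending_failures_by_dest.pop(destination, [])
        # The real code also pops pending PDUs and EDUs, and only skips
        # the send when all three lists are empty.
        if not failures:
            return None
        return {
            "destination": destination,
            "pdu_failures": failures,
        }

    enqueue_failure({"error": "bad signature"}, "remote.example.com")
    print(attempt_new_transaction("remote.example.com"))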

@defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
@@ -570,8 +667,9 @@ class _TransactionQueue(object):
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])

if not pending_pdus and not pending_edus:
if not pending_pdus and not pending_edus and not pending_failures:
return

logger.debug("TX [%s] Attempting new transaction", destination)
@@ -581,7 +679,11 @@ class _TransactionQueue(object):

pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
deferreds = [x[1] for x in pending_pdus + pending_edus]
failures = [x[0].get_dict() for x in pending_failures]
deferreds = [
x[1]
for x in pending_pdus + pending_edus + pending_failures
]

try:
self.pending_transactions[destination] = 1
@@ -589,12 +691,13 @@ class _TransactionQueue(object):
logger.debug("TX [%s] Persisting transaction...", destination)

transaction = Transaction.create_new(
origin_server_ts=self._clock.time_msec(),
origin_server_ts=int(self._clock.time_msec()),
transaction_id=str(self._next_txn_id),
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)

self._next_txn_id += 1
@@ -614,7 +717,9 @@ class _TransactionQueue(object):
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
p["age"] = now - int(p["age_ts"])
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data

code, response = yield self.transport_layer.send_transaction(

@@ -72,7 +72,7 @@ class TransportLayer(object):
self.received_handler = None

@log_function
def get_context_state(self, destination, context):
def get_context_state(self, destination, context, event_id=None):
""" Requests all state for a given context (i.e. room) from the
given server.

@@ -89,54 +89,62 @@ class TransportLayer(object):

subpath = "/state/%s/" % context

return self._do_request_for_transaction(destination, subpath)
args = {}
if event_id:
args["event_id"] = event_id

return self._do_request_for_transaction(
destination, subpath, args=args
)

@log_function
def get_pdu(self, destination, pdu_origin, pdu_id):
def get_event(self, destination, event_id):
""" Requests the pdu with given id and origin from the given server.

Args:
destination (str): The host name of the remote home server we want
to get the state from.
pdu_origin (str): The home server which created the PDU.
pdu_id (str): The id of the PDU being requested.
event_id (str): The id of the event being requested.

Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug("get_pdu dest=%s, pdu_origin=%s, pdu_id=%s",
destination, pdu_origin, pdu_id)
logger.debug("get_pdu dest=%s, event_id=%s",
destination, event_id)

subpath = "/pdu/%s/%s/" % (pdu_origin, pdu_id)
subpath = "/event/%s/" % (event_id, )

return self._do_request_for_transaction(destination, subpath)
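The rename from `get_pdu` to `get_event` changes the wire format: events are now fetched by globally unique event ID alone rather than by an `(origin, pdu_id)` pair. A small sketch of the two URL shapes (the PREFIX value is assumed, and `quote` is added here for safety where the original interpolates the raw string)::

    from urllib import quote  # Python 2, as used by the codebase

    PREFIX = "/_matrix/federation/v1"  # assumed value of the constant

    def old_pdu_path(pdu_origin, pdu_id):
        return PREFIX + "/pdu/%s/%s/" % (quote(pdu_origin), quote(pdu_id))

    def new_event_path(event_id):
        return PREFIX + "/event/%s/" % (quote(event_id),)

    print(old_pdu_path("example.com", "78c"))
    print(new_event_path("$78c:example.com"))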

@log_function
def backfill(self, dest, context, pdu_tuples, limit):
def backfill(self, dest, context, event_tuples, limit):
""" Requests `limit` previous PDUs in a given context before list of
PDUs.

Args:
dest (str)
context (str)
pdu_tuples (list)
event_tuples (list)
limit (int)

Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug(
"backfill dest=%s, context=%s, pdu_tuples=%s, limit=%s",
dest, context, repr(pdu_tuples), str(limit)
"backfill dest=%s, context=%s, event_tuples=%s, limit=%s",
dest, context, repr(event_tuples), str(limit)
)

if not pdu_tuples:
if not event_tuples:
# TODO: raise?
return

subpath = "/backfill/%s/" % context
subpath = "/backfill/%s/" % (context,)

args = {"v": ["%s,%s" % (i, o) for i, o in pdu_tuples]}
args["limit"] = limit
args = {
"v": event_tuples,
"limit": limit,
}

return self._do_request_for_transaction(
dest,
@@ -197,6 +205,72 @@ class TransportLayer(object):

defer.returnValue(response)

@defer.inlineCallbacks
@log_function
def make_join(self, destination, context, user_id, retry_on_dns_fail=True):
path = PREFIX + "/make_join/%s/%s" % (context, user_id,)

response = yield self.client.get_json(
destination=destination,
path=path,
retry_on_dns_fail=retry_on_dns_fail,
)

defer.returnValue(response)

@defer.inlineCallbacks
@log_function
def send_join(self, destination, context, event_id, content):
path = PREFIX + "/send_join/%s/%s" % (
context,
event_id,
)

code, content = yield self.client.put_json(
destination=destination,
path=path,
data=content,
)

if not 200 <= code < 300:
raise RuntimeError("Got %d from send_join", code)

defer.returnValue(json.loads(content))

@defer.inlineCallbacks
@log_function
def send_invite(self, destination, context, event_id, content):
path = PREFIX + "/invite/%s/%s" % (
context,
event_id,
)

code, content = yield self.client.put_json(
destination=destination,
path=path,
data=content,
)

if not 200 <= code < 300:
raise RuntimeError("Got %d from send_invite", code)

defer.returnValue(json.loads(content))

@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, context, event_id):
path = PREFIX + "/event_auth/%s/%s" % (
context,
event_id,
)

response = yield self.client.get_json(
destination=destination,
path=path,
)

defer.returnValue(response)

@defer.inlineCallbacks
def _authenticate_request(self, request):
json_request = {
@@ -210,7 +284,7 @@ class TransportLayer(object):
origin = None

if request.method == "PUT":
#TODO: Handle other method types? other content types?
# TODO: Handle other method types? other content types?
try:
content_bytes = request.content.read()
content = json.loads(content_bytes)
@@ -222,11 +296,13 @@ class TransportLayer(object):
try:
params = auth.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)

def strip_quotes(value):
if value.startswith("\""):
return value[1:-1]
else:
return value

origin = strip_quotes(param_dict["origin"])
key = strip_quotes(param_dict["key"])
sig = strip_quotes(param_dict["sig"])
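The header being picked apart here is the new federation Authorization scheme, `X-Matrix origin=...,key=...,sig=...`. A self-contained sketch of the same parsing steps with a hypothetical header value (the signature is base64 with padding stripped, which is why the naive `split("=")` works)::

    def parse_auth_header(header):
        # e.g. 'X-Matrix origin=remote.example.com,key="ed25519:1",sig="..."'
        params = header.split(" ")[1].split(",")
        param_dict = dict(kv.split("=") for kv in params)

        def strip_quotes(value):
            return value[1:-1] if value.startswith('"') else value

        return (
            strip_quotes(param_dict["origin"]),
            strip_quotes(param_dict["key"]),
            strip_quotes(param_dict["sig"]),
        )

    origin, key, sig = parse_auth_header(
        'X-Matrix origin=remote.example.com,key="ed25519:1",sig="dGVzdA"'
    )
    print(origin, key, sig)  # remote.example.com ed25519:1 dGVzdA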
@@ -247,7 +323,7 @@ class TransportLayer(object):
if auth.startswith("X-Matrix"):
(origin, key, sig) = parse_auth_header(auth)
json_request["origin"] = origin
json_request["signatures"].setdefault(origin,{})[key] = sig
json_request["signatures"].setdefault(origin, {})[key] = sig

if not json_request["signatures"]:
raise SynapseError(
@@ -313,10 +389,10 @@ class TransportLayer(object):
# data_id pair.
self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
re.compile("^" + PREFIX + "/event/([^/]*)/$"),
self._with_authentication(
lambda origin, content, query, pdu_origin, pdu_id:
handler.on_pdu_request(pdu_origin, pdu_id)
lambda origin, content, query, event_id:
handler.on_pdu_request(origin, event_id)
)
)

@@ -326,7 +402,11 @@ class TransportLayer(object):
re.compile("^" + PREFIX + "/state/([^/]*)/$"),
self._with_authentication(
lambda origin, content, query, context:
handler.on_context_state_request(context)
handler.on_context_state_request(
origin,
context,
query.get("event_id", [None])[0],
)
)
)

@@ -336,20 +416,11 @@ class TransportLayer(object):
self._with_authentication(
lambda origin, content, query, context:
self._on_backfill_request(
context, query["v"], query["limit"]
origin, context, query["v"], query["limit"]
)
)
)

self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/context/([^/]*)/$"),
self._with_authentication(
lambda origin, content, query, context:
handler.on_context_pdus_request(context)
)
)

# This is when we receive a server-server Query
self.server.register_path(
"GET",
@@ -362,6 +433,50 @@ class TransportLayer(object):
)
)

self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, user_id:
self._on_make_join_request(
origin, content, query, context, user_id
)
)
)

self.server.register_path(
"GET",
re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, event_id:
handler.on_event_auth(
origin, context, event_id,
)
)
)

self.server.register_path(
"PUT",
re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, event_id:
self._on_send_join_request(
origin, content, query,
)
)
)

self.server.register_path(
"PUT",
re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, event_id:
self._on_invite_request(
origin, content, query,
)
)
)
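Taken together with the client methods earlier in this file, these routes implement the two-step join handshake: the joining server GETs `/make_join/<room>/<user>` to receive a template membership event, completes and signs it locally, then PUTs it to `/send_join/<room>/<event_id>` and receives the room state and auth chain in return. A condensed sketch of the client side of that dance (a hypothetical `client` with the `get_json`/`put_json` shape used above, and `sign_event` standing in for the local hashing/signing step)::

    from twisted.internet import defer
    import json

    PREFIX = "/_matrix/federation/v1"  # assumed value of the constant

    @defer.inlineCallbacks
    def join_room(client, destination, room_id, user_id, sign_event):
        # Step 1: ask a resident server for a template membership event.
        ret = yield client.get_json(
            destination=destination,
            path=PREFIX + "/make_join/%s/%s" % (room_id, user_id),
        )
        event = sign_event(ret["event"])

        # Step 2: send the completed, signed event back.
        code, content = yield client.put_json(
            destination=destination,
            path=PREFIX + "/send_join/%s/%s" % (room_id, event["event_id"]),
            data=event,
        )
        if not 200 <= code < 300:
            raise RuntimeError("send_join failed: %d" % code)

        # The response carries the room state and the event's auth chain.
        defer.returnValue(json.loads(content))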

@defer.inlineCallbacks
@log_function
def _on_send_request(self, origin, content, query, transaction_id):
@@ -402,7 +517,8 @@ class TransportLayer(object):
return

try:
code, response = yield self.received_handler.on_incoming_transaction(
handler = self.received_handler
code, response = yield handler.on_incoming_transaction(
transaction_data
)
except:
@@ -440,7 +556,7 @@ class TransportLayer(object):
defer.returnValue(data)

@log_function
def _on_backfill_request(self, context, v_list, limits):
def _on_backfill_request(self, origin, context, v_list, limits):
if not limits:
return defer.succeed(
(400, {"error": "Did not include limit param"})
@@ -448,124 +564,34 @@ class TransportLayer(object):

limit = int(limits[-1])

versions = [v.split(",", 1) for v in v_list]
versions = v_list

return self.request_handler.on_backfill_request(
context, versions, limit)
origin, context, versions, limit
)
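This hunk also changes the meaning of the backfill `v` query parameter: each entry used to be a `pdu_id,origin` pair that had to be split, whereas now it is an opaque event ID passed straight through. A tiny before/after illustration::

    # Old format: split each version into its two components.
    old_v_list = ["78c,example.com", "90d,other.example.org"]
    old_versions = [v.split(",", 1) for v in old_v_list]
    # -> [['78c', 'example.com'], ['90d', 'other.example.org']]

    # New format: a self-contained event ID needs no splitting.
    new_v_list = ["$78c:example.com", "$90d:other.example.org"]
    new_versions = new_v_list

    print(old_versions)
    print(new_versions)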

@defer.inlineCallbacks
@log_function
def _on_make_join_request(self, origin, content, query, context, user_id):
content = yield self.request_handler.on_make_join_request(
context, user_id,
)
defer.returnValue((200, content))

class TransportReceivedHandler(object):
""" Callbacks used when we receive a transaction
"""
def on_incoming_transaction(self, transaction):
""" Called on PUT /send/<transaction_id>, or on response to a request
that we sent (e.g. a backfill request)
@defer.inlineCallbacks
@log_function
def _on_send_join_request(self, origin, content, query):
content = yield self.request_handler.on_send_join_request(
origin, content,
)

Args:
transaction (synapse.transaction.Transaction): The transaction that
was sent to us.
defer.returnValue((200, content))

Returns:
twisted.internet.defer.Deferred: A deferred that gets fired when
the transaction has finished being processed.
@defer.inlineCallbacks
@log_function
def _on_invite_request(self, origin, content, query):
content = yield self.request_handler.on_invite_request(
origin, content,
)

The result should be a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.

On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass


class TransportRequestHandler(object):
""" Handlers used when someone wants data from us
"""
def on_pull_request(self, versions):
""" Called on GET /pull/?v=...

This is hit when a remote home server wants to get all data
after a given transaction. Mainly used when a home server comes back
online and wants to get everything it has missed.

Args:
versions (list): A list of transaction_ids that should be used to
determine what PDUs the remote side have not yet seen.

Returns:
Deferred: Results in a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.

On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass

def on_pdu_request(self, pdu_origin, pdu_id):
""" Called on GET /pdu/<pdu_origin>/<pdu_id>/

Someone wants a particular PDU. This PDU may or may not have originated
from us.

Args:
pdu_origin (str)
pdu_id (str)

Returns:
Deferred: Results in a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.

On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass

def on_context_state_request(self, context):
""" Called on GET /state/<context>/

Gets hit when someone wants all the *current* state for a given
context.

Args:
context (str): The name of the context that we're interested in.

Returns:
twisted.internet.defer.Deferred: A deferred that gets fired when
the transaction has finished being processed.

The result should be a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.

On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass

def on_backfill_request(self, context, versions, limit):
""" Called on GET /backfill/<context>/?v=...&limit=...

Gets hit when we want to backfill backwards on a given context from
the given point.

Args:
context (str): The context to backfill
versions (list): A list of 2-tuples representing where to backfill
from, in the form `(pdu_id, origin)`
limit (int): How many pdus to return.

Returns:
Deferred: Results in a tuple in the form of
`(response_code, respond_body)`, where `response_body` is a python
dict that will get serialized to JSON.

On errors, the dict should have an `error` key with a brief message
of what went wrong.
"""
pass

def on_query_request(self):
""" Called on a GET /query/<query_type> request. """
defer.returnValue((200, content))

@@ -20,8 +20,6 @@ server protocol.
from synapse.util.jsonobject import JsonEncodedObject

import logging
import json
import copy


logger = logging.getLogger(__name__)
@@ -33,13 +31,13 @@ class Pdu(JsonEncodedObject):

A Pdu can be classified as "state". For a given context, we can efficiently
retrieve all state pdu's that haven't been clobbered. Clobbering is done
via a unique constraint on the tuple (context, pdu_type, state_key). A pdu
via a unique constraint on the tuple (context, type, state_key). A pdu
is a state pdu if `is_state` is True.

Example pdu::

{
"pdu_id": "78c",
"event_id": "$78c:example.com",
"origin_server_ts": 1404835423000,
"origin": "bar",
"prev_ids": [
@@ -52,24 +50,21 @@ class Pdu(JsonEncodedObject):
"""

valid_keys = [
"pdu_id",
"context",
"event_id",
"room_id",
"origin",
"origin_server_ts",
"pdu_type",
"type",
"destinations",
"transaction_id",
"prev_pdus",
"prev_events",
"depth",
"content",
"outlier",
"is_state", # Below this are keys valid only for State Pdus.
"state_key",
"power_level",
"prev_state_id",
"prev_state_origin",
"required_power_level",
"hashes",
"user_id",
"auth_events",
"signatures", # Below this are keys valid only for State Pdus.
"state_key",
"prev_state",
]

internal_keys = [
@@ -79,61 +74,28 @@ class Pdu(JsonEncodedObject):
]

required_keys = [
"pdu_id",
"context",
"event_id",
"room_id",
"origin",
"origin_server_ts",
"pdu_type",
"type",
"content",
]

# TODO: We need to make this properly load content rather than
# just leaving it as a dict. (OR DO WE?!)

def __init__(self, destinations=[], is_state=False, prev_pdus=[],
outlier=False, **kwargs):
if is_state:
for required_key in ["state_key"]:
if required_key not in kwargs:
raise RuntimeError("Key %s is required" % required_key)

def __init__(self, destinations=[], prev_events=[],
outlier=False, hashes={}, signatures={}, **kwargs):
super(Pdu, self).__init__(
destinations=destinations,
is_state=is_state,
prev_pdus=prev_pdus,
prev_events=prev_events,
outlier=outlier,
hashes=hashes,
signatures=signatures,
**kwargs
)
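The `valid_keys` rework above is the core data-model shift of this commit: PDUs are identified by `event_id`/`room_id`, reference parents as `prev_events`, and carry `hashes`, `signatures` and `auth_events`. For orientation, a sketch of a minimally populated new-style PDU dict (field values are illustrative, not taken from the commit)::

    pdu = {
        "event_id": "$78c:example.com",   # globally unique event ID
        "room_id": "!abc:example.com",
        "type": "m.room.message",         # replaces "pdu_type"
        "origin": "example.com",
        "origin_server_ts": 1404835423000,
        "depth": 12,
        "content": {"body": "hello"},
        # Parents are referenced together with their content hashes, so
        # tampering with a parent is detectable from its children.
        "prev_events": [["$77b:example.com", {"sha256": "parenthash"}]],
        "auth_events": [["$70a:example.com", {"sha256": "authhash"}]],
        "hashes": {"sha256": "contenthash"},
        "signatures": {
            "example.com": {"ed25519:1": "base64signature"}
        },
    }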

@classmethod
def from_pdu_tuple(cls, pdu_tuple):
""" Converts a PduTuple to a Pdu

Args:
pdu_tuple (synapse.persistence.transactions.PduTuple): The tuple to
convert

Returns:
Pdu
"""
if pdu_tuple:
d = copy.copy(pdu_tuple.pdu_entry._asdict())
d["origin_server_ts"] = d.pop("ts")

d["content"] = json.loads(d["content_json"])
del d["content_json"]

args = {f: d[f] for f in cls.valid_keys if f in d}
if "unrecognized_keys" in d and d["unrecognized_keys"]:
args.update(json.loads(d["unrecognized_keys"]))

return Pdu(
prev_pdus=pdu_tuple.prev_pdu_list,
**args
)
else:
return None

def __str__(self):
return "(%s, %s)" % (self.__class__.__name__, repr(self.__dict__))

@@ -193,6 +155,7 @@ class Transaction(JsonEncodedObject):
"edus",
"transaction_id",
"destination",
"pdu_failures",
]

internal_keys = [
@@ -229,7 +192,9 @@ class Transaction(JsonEncodedObject):
transaction_id and origin_server_ts keys.
"""
if "origin_server_ts" not in kwargs:
raise KeyError("Require 'origin_server_ts' to construct a Transaction")
raise KeyError(
"Require 'origin_server_ts' to construct a Transaction"
)
if "transaction_id" not in kwargs:
raise KeyError(
"Require 'transaction_id' to construct a Transaction"
@@ -241,6 +206,3 @@ class Transaction(JsonEncodedObject):
kwargs["pdus"] = [p.get_dict() for p in pdus]

return Transaction(**kwargs)


@@ -14,7 +14,18 @@
# limitations under the License.

from twisted.internet import defer

from synapse.api.errors import LimitExceededError
from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.events.room import RoomMemberEvent
from synapse.api.constants import Membership

import logging


logger = logging.getLogger(__name__)


class BaseHandler(object):

@@ -30,6 +41,9 @@ class BaseHandler(object):
self.clock = hs.get_clock()
self.hs = hs

self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname

def ratelimit(self, user_id):
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message(
@@ -44,16 +58,58 @@ class BaseHandler(object):

@defer.inlineCallbacks
def _on_new_room_event(self, event, snapshot, extra_destinations=[],
extra_users=[]):
extra_users=[], suppress_auth=False,
do_invite_host=None):
yield run_on_reactor()

snapshot.fill_out_prev_events(event)

yield self.state_handler.annotate_state_groups(event)

yield self.auth.add_auth_events(event)

logger.debug("Signing event...")

add_hashes_and_signatures(
event, self.server_name, self.signing_key
)

logger.debug("Signed event.")

if not suppress_auth:
logger.debug("Authing...")
self.auth.check(event, raises=True)
logger.debug("Authed")
else:
logger.debug("Suppressed auth.")

if do_invite_host:
federation_handler = self.hs.get_handlers().federation_handler
invite_event = yield federation_handler.send_invite(
do_invite_host,
event
)

# FIXME: We need to check if the remote changed anything else
event.signatures = invite_event.signatures

yield self.store.persist_event(event)

destinations = set(extra_destinations)
# Send a PDU to all hosts who have joined the room.
destinations.update((yield self.store.get_joined_hosts_for_room(
event.room_id
)))

for k, s in event.state_events.items():
try:
if k[0] == RoomMemberEvent.TYPE:
if s.content["membership"] == Membership.JOIN:
destinations.add(
self.hs.parse_userid(s.state_key).domain
)
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
)

event.destinations = list(destinations)

self.notifier.on_new_room_event(event, extra_users=extra_users)
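The reworked `_on_new_room_event` pins down a strict pipeline for locally created events: fill in previous events, annotate state, attach auth events, hash and sign, auth-check (unless suppressed), persist, then fan out. A compressed sketch of just that ordering (collaborators assumed injected as in the handler; `federate` stands in for the destination fan-out shown above)::

    from twisted.internet import defer

    @defer.inlineCallbacks
    def on_new_room_event(self, event, snapshot, suppress_auth=False):
        # 1. Point the event at the current forward extremities.
        snapshot.fill_out_prev_events(event)
        # 2. Work out the room state the event sits on top of.
        yield self.state_handler.annotate_state_groups(event)
        # 3. Record which events authorise this one.
        yield self.auth.add_auth_events(event)
        # 4. Hash and sign only now: the signature must cover the
        #    prev_events and auth_events filled in above.
        add_hashes_and_signatures(event, self.server_name, self.signing_key)
        # 5. Check that our own event passes auth before it leaves us.
        if not suppress_auth:
            self.auth.check(event, raises=True)
        yield self.store.persist_event(event)
        # 6. Federation happens last, once the event is durable locally.
        yield self.federate(event)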

@@ -147,10 +147,8 @@ class DirectoryHandler(BaseHandler):
content={"aliases": aliases},
)

snapshot = yield self.store.snapshot_room(
room_id=room_id,
user_id=user_id,
)
snapshot = yield self.store.snapshot_room(event)

yield self.state_handler.handle_new_event(event, snapshot)
yield self._on_new_room_event(event, snapshot, extra_users=[user_id])
yield self._on_new_room_event(
event, snapshot, extra_users=[user_id], suppress_auth=True
)

@@ -17,13 +17,15 @@

from ._base import BaseHandler

from synapse.api.events.room import InviteJoinEvent, RoomMemberEvent
from synapse.api.errors import AuthError, FederationError
from synapse.api.events.room import RoomMemberEvent
from synapse.api.constants import Membership
from synapse.util.logutils import log_function
from synapse.federation.pdu_codec import PduCodec
from synapse.api.errors import SynapseError
from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import compute_event_signature

from twisted.internet import defer, reactor
from twisted.internet import defer

import logging

@@ -62,6 +64,9 @@ class FederationHandler(BaseHandler):

self.pdu_codec = PduCodec(hs)

# When joining a room we need to queue any events for that room up
self.room_queues = {}

@log_function
@defer.inlineCallbacks
def handle_new_event(self, event, snapshot):
@@ -78,6 +83,8 @@ class FederationHandler(BaseHandler):
processing.
"""

yield run_on_reactor()

pdu = self.pdu_codec.pdu_from_event(event)

if not hasattr(pdu, "destinations") or not pdu.destinations:
@@ -87,97 +94,88 @@ class FederationHandler(BaseHandler):

@log_function
@defer.inlineCallbacks
def on_receive_pdu(self, pdu, backfilled):
def on_receive_pdu(self, pdu, backfilled, state=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it throught the StateHandler.
do auth checks and put it through the StateHandler.
"""
event = self.pdu_codec.event_from_pdu(pdu)

logger.debug("Got event: %s", event.event_id)

with (yield self.lock_manager.lock(pdu.context)):
if event.is_state and not backfilled:
is_new_state = yield self.state_handler.handle_new_state(
pdu
)
else:
is_new_state = False
if event.room_id in self.room_queues:
self.room_queues[event.room_id].append(pdu)
return

logger.debug("Processing event: %s", event.event_id)

if state:
state = [self.pdu_codec.event_from_pdu(p) for p in state]

is_new_state = yield self.state_handler.annotate_state_groups(
event,
old_state=state
)

logger.debug("Event: %s", event)

try:
self.auth.check(event, raises=True)
except AuthError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event.event_id,
)

is_new_state = is_new_state and not backfilled

# TODO: Implement something in federation that allows us to
# respond to PDU.

target_is_mine = False
if hasattr(event, "target_host"):
target_is_mine = event.target_host == self.hs.hostname
yield self.store.persist_event(
event,
backfilled,
is_new_state=is_new_state
)

if event.type == InviteJoinEvent.TYPE:
if not target_is_mine:
logger.debug("Ignoring invite/join event %s", event)
return
room = yield self.store.get_room(event.room_id)

# If we receive an invite/join event then we need to join the
# sender to the given room.
# TODO: We should probably auth this or some such
content = event.content
content.update({"membership": Membership.JOIN})
new_event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
state_key=event.user_id,
room_id=event.room_id,
user_id=event.user_id,
membership=Membership.JOIN,
content=content
)

yield self.hs.get_handlers().room_member_handler.change_membership(
new_event,
do_auth=False,
)

else:
with (yield self.room_lock.lock(event.room_id)):
yield self.store.persist_event(
event,
backfilled,
is_new_state=is_new_state
if not room:
# Huh, let's try and get the current state
try:
yield self.replication_layer.get_state_for_context(
event.origin, event.room_id, event.event_id,
)

room = yield self.store.get_room(event.room_id)

if not room:
# Huh, let's try and get the current state
try:
yield self.replication_layer.get_state_for_context(
event.origin, event.room_id
)

hosts = yield self.store.get_joined_hosts_for_room(
event.room_id
)
if self.hs.hostname in hosts:
try:
yield self.store.store_room(
room_id=event.room_id,
room_creator_user_id="",
is_public=False,
)
except:
pass
except:
logger.exception(
"Failed to get current state for room %s",
event.room_id
)

if not backfilled:
extra_users = []
if event.type == RoomMemberEvent.TYPE:
target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id)
extra_users.append(target_user)

yield self.notifier.on_new_room_event(
event, extra_users=extra_users
hosts = yield self.store.get_joined_hosts_for_room(
event.room_id
)
if self.hs.hostname in hosts:
try:
yield self.store.store_room(
room_id=event.room_id,
room_creator_user_id="",
is_public=False,
)
except:
pass
except:
logger.exception(
"Failed to get current state for room %s",
event.room_id
)

if not backfilled:
extra_users = []
if event.type == RoomMemberEvent.TYPE:
target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id)
extra_users.append(target_user)

yield self.notifier.on_new_room_event(
event, extra_users=extra_users
)

if event.type == RoomMemberEvent.TYPE:
if event.membership == Membership.JOIN:
@@ -189,79 +187,344 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit):
pdus = yield self.replication_layer.backfill(dest, room_id, limit)
extremities = yield self.store.get_oldest_events_in_room(room_id)

pdus = yield self.replication_layer.backfill(
dest,
room_id,
limit,
extremities=extremities,
)

events = []

for pdu in pdus:
event = self.pdu_codec.event_from_pdu(pdu)

# FIXME (erikj): Not sure this actually works :/
yield self.state_handler.annotate_state_groups(event)

events.append(event)

yield self.store.persist_event(event, backfilled=True)

defer.returnValue(events)

@defer.inlineCallbacks
def send_invite(self, target_host, event):
pdu = yield self.replication_layer.send_invite(
destination=target_host,
context=event.room_id,
event_id=event.event_id,
pdu=self.pdu_codec.pdu_from_event(event)
)

defer.returnValue(self.pdu_codec.event_from_pdu(pdu))

@defer.inlineCallbacks
def on_event_auth(self, event_id):
auth = yield self.store.get_auth_chain(event_id)
defer.returnValue([self.pdu_codec.pdu_from_event(e) for e in auth])

@log_function
@defer.inlineCallbacks
def do_invite_join(self, target_host, room_id, joinee, content, snapshot):

hosts = yield self.store.get_joined_hosts_for_room(room_id)
if self.hs.hostname in hosts:
# We are already in the room.
logger.debug("We're already in the room apparently")
defer.returnValue(False)

# First get current state to see if we are already joined.
try:
yield self.replication_layer.get_state_for_context(
target_host, room_id
)

hosts = yield self.store.get_joined_hosts_for_room(room_id)
if self.hs.hostname in hosts:
# Oh, we were actually in the room already.
logger.debug("We're already in the room apparently")
defer.returnValue(False)
except Exception:
logger.exception("Failed to get current state")

new_event = self.event_factory.create_event(
etype=InviteJoinEvent.TYPE,
target_host=target_host,
room_id=room_id,
user_id=joinee,
content=content
pdu = yield self.replication_layer.make_join(
target_host,
room_id,
joinee
)

new_event.destinations = [target_host]
logger.debug("Got response to make_join: %s", pdu)

snapshot.fill_out_prev_events(new_event)
yield self.handle_new_event(new_event, snapshot)
event = self.pdu_codec.event_from_pdu(pdu)

# TODO (erikj): Time out here.
d = defer.Deferred()
self.waiting_for_join_list.setdefault((joinee, room_id), []).append(d)
reactor.callLater(10, d.cancel)
# We should assert some things.
assert(event.type == RoomMemberEvent.TYPE)
assert(event.user_id == joinee)
assert(event.state_key == joinee)
assert(event.room_id == room_id)

event.outlier = False

self.room_queues[room_id] = []

try:
yield d
except defer.CancelledError:
raise SynapseError(500, "Unable to join remote room")
event.event_id = self.event_factory.create_event_id()
event.content = content

try:
yield self.store.store_room(
room_id=room_id,
room_creator_user_id="",
is_public=False
state = yield self.replication_layer.send_join(
target_host,
self.pdu_codec.pdu_from_event(event)
)
except:
pass

state = [self.pdu_codec.event_from_pdu(p) for p in state]

logger.debug("do_invite_join state: %s", state)

is_new_state = yield self.state_handler.annotate_state_groups(
event,
old_state=state
)

logger.debug("do_invite_join event: %s", event)

try:
yield self.store.store_room(
room_id=room_id,
room_creator_user_id="",
is_public=False
)
except:
# FIXME
pass

for e in state:
# FIXME: Auth these.
e.outlier = True

yield self.state_handler.annotate_state_groups(
e,
)

yield self.store.persist_event(
e,
backfilled=False,
is_new_state=False
)

yield self.store.persist_event(
event,
backfilled=False,
is_new_state=is_new_state
)
finally:
room_queue = self.room_queues[room_id]
del self.room_queues[room_id]

for p in room_queue:
try:
yield self.on_receive_pdu(p, backfilled=False)
except:
pass

defer.returnValue(True)
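A notable detail of the new `do_invite_join` is the `room_queues` buffer: PDUs that arrive for a room while we are still joining it are parked and replayed once the join settles, success or not (hence the `finally`). A stripped-down sketch of the pattern::

    room_queues = {}

    def process(pdu):
        print("processing", pdu["event_id"])

    def on_receive_pdu(pdu):
        room_id = pdu["room_id"]
        if room_id in room_queues:
            # Join in progress: park the PDU rather than processing it
            # against state we do not have yet.
            room_queues[room_id].append(pdu)
            return
        process(pdu)

    def do_join(room_id, perform_join):
        room_queues[room_id] = []
        try:
            perform_join(room_id)
        finally:
            # Replay whatever arrived mid-join, whether or not it worked.
            for pdu in room_queues.pop(room_id):
                try:
                    process(pdu)
                except Exception:
                    pass  # mirrors the original's best-effort replay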

@defer.inlineCallbacks
@log_function
def on_make_join_request(self, context, user_id):
event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
content={"membership": Membership.JOIN},
room_id=context,
user_id=user_id,
state_key=user_id,
)

snapshot = yield self.store.snapshot_room(event)
snapshot.fill_out_prev_events(event)

yield self.state_handler.annotate_state_groups(event)
yield self.auth.add_auth_events(event)
self.auth.check(event, raises=True)

pdu = self.pdu_codec.pdu_from_event(event)

defer.returnValue(pdu)

@defer.inlineCallbacks
@log_function
def on_send_join_request(self, origin, pdu):
event = self.pdu_codec.event_from_pdu(pdu)

event.outlier = False

is_new_state = yield self.state_handler.annotate_state_groups(event)
self.auth.check(event, raises=True)

# FIXME (erikj): All this is duplicated above :(

yield self.store.persist_event(
event,
backfilled=False,
is_new_state=is_new_state
)

extra_users = []
if event.type == RoomMemberEvent.TYPE:
target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id)
extra_users.append(target_user)

yield self.notifier.on_new_room_event(
event, extra_users=extra_users
)

if event.type == RoomMemberEvent.TYPE:
if event.membership == Membership.JOIN:
user = self.hs.parse_userid(event.state_key)
self.distributor.fire(
"user_joined_room", user=user, room_id=event.room_id
)

new_pdu = self.pdu_codec.pdu_from_event(event)

destinations = set()

for k, s in event.state_events.items():
try:
if k[0] == RoomMemberEvent.TYPE:
if s.content["membership"] == Membership.JOIN:
destinations.add(
self.hs.parse_userid(s.state_key).domain
)
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
)

new_pdu.destinations = list(destinations)

yield self.replication_layer.send_pdu(new_pdu)

auth_chain = yield self.store.get_auth_chain(event.event_id)
pdu_auth_chain = [
self.pdu_codec.pdu_from_event(e)
for e in auth_chain
]

defer.returnValue({
"state": [
self.pdu_codec.pdu_from_event(e)
for e in event.state_events.values()
],
"auth_chain": pdu_auth_chain,
})

@defer.inlineCallbacks
def on_invite_request(self, origin, pdu):
event = self.pdu_codec.event_from_pdu(pdu)

event.outlier = True

event.signatures.update(
compute_event_signature(
event,
self.hs.hostname,
self.hs.config.signing_key[0]
)
)

yield self.state_handler.annotate_state_groups(event)

yield self.store.persist_event(
event,
backfilled=False,
)

target_user = self.hs.parse_userid(event.state_key)
yield self.notifier.on_new_room_event(
event, extra_users=[target_user],
)

defer.returnValue(self.pdu_codec.pdu_from_event(event))
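Note how `on_invite_request` counter-signs: the invited user's server merges its own signature into the event's `signatures` map before handing the event back, so the inviting server ends up holding an event signed by both parties. A sketch of the merge semantics (dict shapes follow the signing format used elsewhere in this commit; names are illustrative)::

    def countersign(signatures, my_server, key_id, signature_b64):
        # signatures is {server_name: {key_id: base64_sig}}; update()
        # adds our block without disturbing the sender's signature.
        signatures.update({my_server: {key_id: signature_b64}})
        return signatures

    sigs = {"inviter.example.com": {"ed25519:1": "sigA"}}
    countersign(sigs, "invited.example.com", "ed25519:1", "sigB")
    print(sigs)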
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_pdu(self, origin, room_id, event_id):
|
||||
yield run_on_reactor()
|
||||
|
||||
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
state_groups = yield self.store.get_state_groups(
|
||||
[event_id]
|
||||
)
|
||||
|
||||
if state_groups:
|
||||
_, state = state_groups.items().pop()
|
||||
results = {
|
||||
(e.type, e.state_key): e for e in state
|
||||
}
|
||||
|
||||
event = yield self.store.get_event(event_id)
|
||||
if hasattr(event, "state_key"):
|
||||
# Get previous state
|
||||
if hasattr(event, "replaces_state") and event.replaces_state:
|
||||
prev_event = yield self.store.get_event(
|
||||
event.replaces_state
|
||||
)
|
||||
results[(event.type, event.state_key)] = prev_event
|
||||
else:
|
||||
del results[(event.type, event.state_key)]
|
||||
|
||||
defer.returnValue(
|
||||
[
|
||||
self.pdu_codec.pdu_from_event(s)
|
||||
for s in results.values()
|
||||
]
|
||||
)
|
||||
else:
|
||||
defer.returnValue([])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_backfill_request(self, origin, context, pdu_list, limit):
|
||||
in_room = yield self.auth.check_host_in_room(context, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
events = yield self.store.get_backfill_events(
|
||||
context,
|
||||
pdu_list,
|
||||
limit
|
||||
)
|
||||
|
||||
defer.returnValue([
|
||||
self.pdu_codec.pdu_from_event(e)
|
||||
for e in events
|
||||
])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_persisted_pdu(self, origin, event_id):
|
||||
""" Get a PDU from the database with given origin and id.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a `Pdu`.
|
||||
"""
|
||||
event = yield self.store.get_event(
|
||||
event_id,
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if event:
|
||||
in_room = yield self.auth.check_host_in_room(
|
||||
event.room_id,
|
||||
origin
|
||||
)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
defer.returnValue(self.pdu_codec.pdu_from_event(event))
|
||||
else:
|
||||
defer.returnValue(None)
|
||||
|
||||
@log_function
|
||||
def get_min_depth_for_context(self, context):
|
||||
return self.store.get_min_depth(context)
|
||||
|
||||
@log_function
|
||||
def _on_user_joined(self, user, room_id):
|
||||
waiters = self.waiting_for_join_list.get((user.to_string(), room_id), [])
|
||||
waiters = self.waiting_for_join_list.get(
|
||||
(user.to_string(), room_id),
|
||||
[]
|
||||
)
|
||||
while waiters:
|
||||
waiters.pop().callback(None)
|
||||
|
@ -81,12 +81,11 @@ class MessageHandler(BaseHandler):
        user = self.hs.parse_userid(event.user_id)
        assert user.is_mine, "User must be our own: %s" % (user,)

        snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
        snapshot = yield self.store.snapshot_room(event)

        if not suppress_auth:
            yield self.auth.check(event, snapshot, raises=True)

        yield self._on_new_room_event(event, snapshot)
        yield self._on_new_room_event(
            event, snapshot, suppress_auth=suppress_auth
        )

        self.hs.get_handlers().presence_handler.bump_presence_active_time(
            user
@ -142,16 +141,7 @@ class MessageHandler(BaseHandler):
            SynapseError if something went wrong.
        """

        snapshot = yield self.store.snapshot_room(
            event.room_id,
            event.user_id,
            state_type=event.type,
            state_key=event.state_key,
        )

        yield self.auth.check(event, snapshot, raises=True)

        yield self.state_handler.handle_new_event(event, snapshot)
        snapshot = yield self.store.snapshot_room(event)

        yield self._on_new_room_event(event, snapshot)

@ -201,7 +191,7 @@ class MessageHandler(BaseHandler):
            raise RoomError(
                403, "Member does not meet private room rules.")

        data = yield self.store.get_current_state(
        data = yield self.state_handler.get_current_state(
            room_id, event_type, state_key
        )
        defer.returnValue(data)
@ -219,9 +209,7 @@ class MessageHandler(BaseHandler):

    @defer.inlineCallbacks
    def send_feedback(self, event):
        snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)

        yield self.auth.check(event, snapshot, raises=True)
        snapshot = yield self.store.snapshot_room(event)

        # store message in db
        yield self._on_new_room_event(event, snapshot)
@ -239,7 +227,7 @@ class MessageHandler(BaseHandler):
        yield self.auth.check_joined_room(room_id, user_id)

        # TODO: This is duplicating logic from snapshot_all_rooms
        current_state = yield self.store.get_current_state(room_id)
        current_state = yield self.state_handler.get_current_state(room_id)
        defer.returnValue([self.hs.serialize_event(c) for c in current_state])

    @defer.inlineCallbacks
@ -316,7 +304,7 @@ class MessageHandler(BaseHandler):
                "end": end_token.to_string(),
            }

            current_state = yield self.store.get_current_state(
            current_state = yield self.state_handler.get_current_state(
                event.room_id
            )
            d["state"] = [self.hs.serialize_event(c) for c in current_state]
@ -17,7 +17,6 @@ from twisted.internet import defer

from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import Membership
from synapse.api.events.room import RoomMemberEvent

from ._base import BaseHandler

@ -153,10 +152,13 @@ class ProfileHandler(BaseHandler):
        if not user.is_mine:
            defer.returnValue(None)

        (displayname, avatar_url) = yield defer.gatherResults([
            self.store.get_profile_displayname(user.localpart),
            self.store.get_profile_avatar_url(user.localpart),
        ])
        (displayname, avatar_url) = yield defer.gatherResults(
            [
                self.store.get_profile_displayname(user.localpart),
                self.store.get_profile_avatar_url(user.localpart),
            ],
            consumeErrors=True
        )

        state["displayname"] = displayname
        state["avatar_url"] = avatar_url
@ -196,10 +198,7 @@ class ProfileHandler(BaseHandler):
        )

        for j in joins:
            snapshot = yield self.store.snapshot_room(
                j.room_id, j.state_key, RoomMemberEvent.TYPE,
                j.state_key
            )
            snapshot = yield self.store.snapshot_room(j)

            content = {
                "membership": j.content["membership"],
@ -218,5 +217,6 @@ class ProfileHandler(BaseHandler):
                user_id=j.state_key,
            )

            yield self.state_handler.handle_new_event(new_event, snapshot)
            yield self._on_new_room_event(new_event, snapshot)
            yield self._on_new_room_event(
                new_event, snapshot, suppress_auth=True
            )
@ -21,8 +21,7 @@ from synapse.api.constants import Membership, JoinRules
from synapse.api.errors import StoreError, SynapseError
from synapse.api.events.room import (
    RoomMemberEvent, RoomCreateEvent, RoomPowerLevelsEvent,
    RoomJoinRulesEvent, RoomAddStateLevelEvent, RoomTopicEvent,
    RoomSendEventLevelEvent, RoomOpsPowerLevelsEvent, RoomNameEvent,
    RoomTopicEvent, RoomNameEvent, RoomJoinRulesEvent,
)
from synapse.util import stringutils
from ._base import BaseHandler
@ -122,15 +121,13 @@ class RoomCreationHandler(BaseHandler):

        @defer.inlineCallbacks
        def handle_event(event):
            snapshot = yield self.store.snapshot_room(
                room_id=room_id,
                user_id=user_id,
            )
            snapshot = yield self.store.snapshot_room(event)

            logger.debug("Event: %s", event)

            yield self.state_handler.handle_new_event(event, snapshot)
            yield self._on_new_room_event(event, snapshot, extra_users=[user])
            yield self._on_new_room_event(
                event, snapshot, extra_users=[user], suppress_auth=True
            )

        for event in creation_events:
            yield handle_event(event)
@ -141,7 +138,6 @@ class RoomCreationHandler(BaseHandler):
                etype=RoomNameEvent.TYPE,
                room_id=room_id,
                user_id=user_id,
                required_power_level=50,
                content={"name": name},
            )

@ -153,7 +149,6 @@ class RoomCreationHandler(BaseHandler):
                etype=RoomTopicEvent.TYPE,
                room_id=room_id,
                user_id=user_id,
                required_power_level=50,
                content={"topic": topic},
            )

@ -198,7 +193,6 @@ class RoomCreationHandler(BaseHandler):
        event_keys = {
            "room_id": room_id,
            "user_id": creator.to_string(),
            "required_power_level": 100,
        }

        def create(etype, **content):
@ -215,7 +209,21 @@ class RoomCreationHandler(BaseHandler):

        power_levels_event = self.event_factory.create_event(
            etype=RoomPowerLevelsEvent.TYPE,
            content={creator.to_string(): 100, "default": 0},
            content={
                "users": {
                    creator.to_string(): 100,
                },
                "users_default": 0,
                "events": {
                    RoomNameEvent.TYPE: 100,
                    RoomPowerLevelsEvent.TYPE: 100,
                },
                "events_default": 0,
                "state_default": 50,
                "ban": 50,
                "kick": 50,
                "redact": 50
            },
            **event_keys
        )
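
The hunk above swaps the old flat power-level content for the structured form ("users", "users_default", "events", "events_default", plus the ban/kick/redact levels) that the rest of this commit reads back. A minimal sketch of the lookup this format implies (the helper name is hypothetical; `_get_power_level_from_event_state` in synapse/state.py further down is the version the commit actually adds):

    def effective_power_level(content, user_id):
        # An explicit per-user entry wins; otherwise fall back to users_default.
        level = content.get("users", {}).get(user_id)
        if level is None:
            level = content.get("users_default", 0)
        return int(level)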
@ -225,30 +233,10 @@ class RoomCreationHandler(BaseHandler):
            join_rule=join_rule,
        )

        add_state_event = create(
            etype=RoomAddStateLevelEvent.TYPE,
            level=100,
        )

        send_event = create(
            etype=RoomSendEventLevelEvent.TYPE,
            level=0,
        )

        ops = create(
            etype=RoomOpsPowerLevelsEvent.TYPE,
            ban_level=50,
            kick_level=50,
            redact_level=50,
        )

        return [
            creation_event,
            power_levels_event,
            join_rules_event,
            add_state_event,
            send_event,
            ops,
        ]


@ -363,10 +351,8 @@ class RoomMemberHandler(BaseHandler):
        """
        target_user_id = event.state_key

        snapshot = yield self.store.snapshot_room(
            event.room_id, event.user_id,
            RoomMemberEvent.TYPE, target_user_id
        )
        snapshot = yield self.store.snapshot_room(event)

        ## TODO(markjh): get prev state from snapshot.
        prev_state = yield self.store.get_room_member(
            target_user_id, event.room_id
@ -375,13 +361,6 @@ class RoomMemberHandler(BaseHandler):
        if prev_state:
            event.content["prev"] = prev_state.membership

        # if prev_state and prev_state.membership == event.membership:
        #     # treat this event as a NOOP.
        #     if do_auth:  # This is mainly to fix a unit test.
        #         yield self.auth.check(event, raises=True)
        #     defer.returnValue({})
        #     return

        room_id = event.room_id

        # If we're trying to join a room then we have to do this differently
@ -391,29 +370,17 @@ class RoomMemberHandler(BaseHandler):
            yield self._do_join(event, snapshot, do_auth=do_auth)
        else:
            # This is not a JOIN, so we can handle it normally.
            if do_auth:
                yield self.auth.check(event, snapshot, raises=True)

            # If we're banning someone, set a req power level
            if event.membership == Membership.BAN:
                if not hasattr(event, "required_power_level") or event.required_power_level is None:
                    # Add some default required_power_level
                    user_level = yield self.store.get_power_level(
                        event.room_id,
                        event.user_id,
                    )
                    event.required_power_level = user_level

            if prev_state and prev_state.membership == event.membership:
                # double same action, treat this event as a NOOP.
                defer.returnValue({})
                return

            yield self.state_handler.handle_new_event(event, snapshot)
            yield self._do_local_membership_update(
                event,
                membership=event.content["membership"],
                snapshot=snapshot,
                do_auth=do_auth,
            )

            defer.returnValue({"room_id": room_id})
@ -443,10 +410,7 @@ class RoomMemberHandler(BaseHandler):
            content=content,
        )

        snapshot = yield self.store.snapshot_room(
            room_id, joinee.to_string(), RoomMemberEvent.TYPE,
            joinee.to_string()
        )
        snapshot = yield self.store.snapshot_room(new_event)

        yield self._do_join(new_event, snapshot, room_host=host, do_auth=True)

@ -502,14 +466,11 @@ class RoomMemberHandler(BaseHandler):
        if not have_joined:
            logger.debug("Doing normal join")

            if do_auth:
                yield self.auth.check(event, snapshot, raises=True)

            yield self.state_handler.handle_new_event(event, snapshot)
            yield self._do_local_membership_update(
                event,
                membership=event.content["membership"],
                snapshot=snapshot,
                do_auth=do_auth,
            )

        user = self.hs.parse_userid(event.user_id)
@ -553,26 +514,27 @@ class RoomMemberHandler(BaseHandler):

        defer.returnValue([r.room_id for r in rooms])

    def _do_local_membership_update(self, event, membership, snapshot):
        destinations = []

    @defer.inlineCallbacks
    def _do_local_membership_update(self, event, membership, snapshot,
                                    do_auth):
        # If we're inviting someone, then we should also send it to that
        # HS.
        target_user_id = event.state_key
        target_user = self.hs.parse_userid(target_user_id)
        if membership == Membership.INVITE:
            host = target_user.domain
            destinations.append(host)
        if membership == Membership.INVITE and not target_user.is_mine:
            do_invite_host = target_user.domain
        else:
            do_invite_host = None

        # Always include target domain
        host = target_user.domain
        destinations.append(host)

        return self._on_new_room_event(
            event, snapshot, extra_destinations=destinations,
            extra_users=[target_user]
        yield self._on_new_room_event(
            event,
            snapshot,
            extra_users=[target_user],
            suppress_auth=(not do_auth),
            do_invite_host=do_invite_host,
        )


class RoomListHandler(BaseHandler):

    @defer.inlineCallbacks
@ -18,6 +18,11 @@ from synapse.api.urls import CLIENT_PREFIX
from synapse.rest.transactions import HttpTransactionStore
import re

import logging


logger = logging.getLogger(__name__)


def client_path_pattern(path_regex):
    """Creates a regex compiled client path with the correct client path
@ -62,6 +67,8 @@ class RestServlet(object):
        self.auth = hs.get_auth()
        self.txns = HttpTransactionStore()

        self.validator = hs.get_event_validator()

    def register(self, http_server):
        """ Register this servlet with the given HTTP server. """
        if hasattr(self, "PATTERN"):
@ -20,6 +20,12 @@ from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig
from synapse.rest.base import RestServlet, client_path_pattern

import logging


logger = logging.getLogger(__name__)


class EventStreamRestServlet(RestServlet):
    PATTERN = client_path_pattern("/events$")
@ -29,18 +35,22 @@ class EventStreamRestServlet(RestServlet):
    @defer.inlineCallbacks
    def on_GET(self, request):
        auth_user = yield self.auth.get_user_by_req(request)
        try:
            handler = self.handlers.event_stream_handler
            pagin_config = PaginationConfig.from_request(request)
            timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
            if "timeout" in request.args:
                try:
                    timeout = int(request.args["timeout"][0])
                except ValueError:
                    raise SynapseError(400, "timeout must be in milliseconds.")

        handler = self.handlers.event_stream_handler
        pagin_config = PaginationConfig.from_request(request)
        timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
        if "timeout" in request.args:
            try:
                timeout = int(request.args["timeout"][0])
            except ValueError:
                raise SynapseError(400, "timeout must be in milliseconds.")

        chunk = yield handler.get_stream(auth_user.to_string(), pagin_config,
                                         timeout=timeout)
            chunk = yield handler.get_stream(
                auth_user.to_string(), pagin_config, timeout=timeout
            )
        except:
            logger.exception("Event stream failed")
            raise

        defer.returnValue((200, chunk))

@ -138,7 +138,7 @@ class RoomStateEventRestServlet(RestServlet):
            raise SynapseError(
                404, "Event not found.", errcode=Codes.NOT_FOUND
            )
        defer.returnValue((200, data[0].get_dict()["content"]))
        defer.returnValue((200, data.get_dict()["content"]))

    @defer.inlineCallbacks
    def on_PUT(self, request, room_id, event_type, state_key):
@ -154,6 +154,9 @@ class RoomStateEventRestServlet(RestServlet):
            user_id=user.to_string(),
            state_key=urllib.unquote(state_key)
        )

        self.validator.validate(event)

        if event_type == RoomMemberEvent.TYPE:
            # membership events are special
            handler = self.handlers.room_member_handler
@ -188,6 +191,8 @@ class RoomSendEventRestServlet(RestServlet):
            content=content
        )

        self.validator.validate(event)

        msg_handler = self.handlers.message_handler
        yield msg_handler.send_message(event)

@ -253,6 +258,9 @@ class JoinRoomAliasServlet(RestServlet):
            user_id=user.to_string(),
            state_key=user.to_string()
        )

        self.validator.validate(event)

        handler = self.handlers.room_member_handler
        yield handler.change_membership(event)
        defer.returnValue((200, {}))
@ -424,6 +432,9 @@ class RoomMembershipRestServlet(RestServlet):
            user_id=user.to_string(),
            state_key=state_key
        )

        self.validator.validate(event)

        handler = self.handlers.room_member_handler
        yield handler.change_membership(event)
        defer.returnValue((200, {}))
@ -461,6 +472,8 @@ class RoomRedactEventRestServlet(RestServlet):
            redacts=urllib.unquote(event_id),
        )

        self.validator.validate(event)

        msg_handler = self.handlers.message_handler
        yield msg_handler.send_message(event)

@ -22,13 +22,14 @@
from synapse.federation import initialize_http_replication
from synapse.api.events import serialize_event
from synapse.api.events.factory import EventFactory
from synapse.api.events.validator import EventValidator
from synapse.notifier import Notifier
from synapse.api.auth import Auth
from synapse.handlers import Handlers
from synapse.rest import RestServletFactory
from synapse.state import StateHandler
from synapse.storage import DataStore
from synapse.types import UserID, RoomAlias, RoomID
from synapse.types import UserID, RoomAlias, RoomID, EventID
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.lockutils import LockManager
@ -80,6 +81,7 @@ class BaseHomeServer(object):
        'event_sources',
        'ratelimiter',
        'keyring',
        'event_validator',
    ]

    def __init__(self, hostname, **kwargs):
@ -143,6 +145,11 @@ class BaseHomeServer(object):
        object."""
        return RoomID.from_string(s, hs=self)

    def parse_eventid(self, s):
        """Parse the string given by 's' as a Event ID and return a EventID
        object."""
        return EventID.from_string(s, hs=self)

    def serialize_event(self, e):
        return serialize_event(self, e)

@ -218,6 +225,9 @@ class HomeServer(BaseHomeServer):
    def build_keyring(self):
        return Keyring(self)

    def build_event_validator(self):
        return EventValidator(self)

    def register_servlets(self):
        """ Register all servlets associated with this HomeServer.
        """
327
synapse/state.py
@ -16,11 +16,13 @@

from twisted.internet import defer

from synapse.federation.pdu_codec import encode_event_id, decode_event_id
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.api.events.room import RoomPowerLevelsEvent

from collections import namedtuple

import copy
import logging
import hashlib

@ -35,230 +37,169 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))


class StateHandler(object):
    """ Repsonsible for doing state conflict resolution.
    """ Responsible for doing state conflict resolution.
    """

    def __init__(self, hs):
        self.store = hs.get_datastore()
        self._replication = hs.get_replication_layer()
        self.server_name = hs.hostname

    @defer.inlineCallbacks
    @log_function
    def handle_new_event(self, event, snapshot):
        """ Given an event this works out if a) we have sufficient power level
        to update the state and b) works out what the prev_state should be.
    def annotate_state_groups(self, event, old_state=None):
        yield run_on_reactor()

        Returns:
            Deferred: Resolved with a boolean indicating if we succesfully
                updated the state.
        if old_state:
            event.state_group = None
            event.old_state_events = {
                (s.type, s.state_key): s for s in old_state
            }
            event.state_events = event.old_state_events

        Raised:
            AuthError
        """
        # This needs to be done in a transaction.
            if hasattr(event, "state_key"):
                event.state_events[(event.type, event.state_key)] = event

        if not hasattr(event, "state_key"):
            defer.returnValue(False)
            return

        key = KeyStateTuple(
            event.room_id,
            event.type,
            _get_state_key_from_event(event)
        )
        if hasattr(event, "outlier") and event.outlier:
            event.state_group = None
            event.old_state_events = None
            event.state_events = {}
            defer.returnValue(False)
            return

        # Now I need to fill out the prev state and work out if it has auth
        # (w.r.t. to power levels)
        ids = [e for e, _ in event.prev_events]

        snapshot.fill_out_prev_events(event)
        ret = yield self.resolve_state_groups(ids)
        state_group, new_state = ret

        event.prev_events = [
            e for e in event.prev_events if e != event.event_id
        event.old_state_events = copy.deepcopy(new_state)

        if hasattr(event, "state_key"):
            key = (event.type, event.state_key)
            if key in new_state:
                event.replaces_state = new_state[key].event_id
            new_state[key] = event
        elif state_group:
            event.state_group = state_group
            event.state_events = new_state
            defer.returnValue(False)

        event.state_group = None
        event.state_events = new_state

        defer.returnValue(hasattr(event, "state_key"))

    @defer.inlineCallbacks
    def get_current_state(self, room_id, event_type=None, state_key=""):
        events = yield self.store.get_latest_events_in_room(room_id)

        event_ids = [
            e_id
            for e_id, _, _ in events
        ]

        current_state = snapshot.prev_state_pdu
        res = yield self.resolve_state_groups(event_ids)

        if current_state:
            event.prev_state = encode_event_id(
                current_state.pdu_id, current_state.origin
            )

        # TODO check current_state to see if the min power level is less
        # than the power level of the user
        # power_level = self._get_power_level_for_event(event)

        pdu_id, origin = decode_event_id(event.event_id, self.server_name)

        yield self.store.update_current_state(
            pdu_id=pdu_id,
            origin=origin,
            context=key.context,
            pdu_type=key.type,
            state_key=key.state_key
        )

        defer.returnValue(True)

    @defer.inlineCallbacks
    @log_function
    def handle_new_state(self, new_pdu):
        """ Apply conflict resolution to `new_pdu`.

        This should be called on every new state pdu, regardless of whether or
        not there is a conflict.

        This function is safe against the race of it getting called with two
        `PDU`s trying to update the same state.
        """

        # This needs to be done in a transaction.

        is_new = yield self._handle_new_state(new_pdu)

        logger.debug("is_new: %s %s %s", is_new, new_pdu.pdu_id, new_pdu.origin)

        if is_new:
            yield self.store.update_current_state(
                pdu_id=new_pdu.pdu_id,
                origin=new_pdu.origin,
                context=new_pdu.context,
                pdu_type=new_pdu.pdu_type,
                state_key=new_pdu.state_key
            )

        defer.returnValue(is_new)

    def _get_power_level_for_event(self, event):
        # return self._persistence.get_power_level_for_user(event.room_id,
        #     event.sender)
        return event.power_level

    @defer.inlineCallbacks
    @log_function
    def _handle_new_state(self, new_pdu):
        tree, missing_branch = yield self.store.get_unresolved_state_tree(
            new_pdu
        )
        new_branch, current_branch = tree

        logger.debug(
            "_handle_new_state new=%s, current=%s",
            new_branch, current_branch
        )

        if missing_branch is not None:
            # We're missing some PDUs. Fetch them.
            # TODO (erikj): Limit this.
            missing_prev = tree[missing_branch][-1]

            pdu_id = missing_prev.prev_state_id
            origin = missing_prev.prev_state_origin

            is_missing = yield self.store.get_pdu(pdu_id, origin) is None
            if not is_missing:
                raise Exception("Conflict resolution failed")

            yield self._replication.get_pdu(
                destination=missing_prev.origin,
                pdu_origin=origin,
                pdu_id=pdu_id,
                outlier=True
            )

            updated_current = yield self._handle_new_state(new_pdu)
            defer.returnValue(updated_current)

        if not current_branch:
            # There is no current state
            defer.returnValue(True)
        if event_type:
            defer.returnValue(res[1].get((event_type, state_key)))
            return

        n = new_branch[-1]
        c = current_branch[-1]
        defer.returnValue(res[1].values())

        common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin
    @defer.inlineCallbacks
    @log_function
    def resolve_state_groups(self, event_ids):
        state_groups = yield self.store.get_state_groups(
            event_ids
        )

        if common_ancestor:
            # We found a common ancestor!
        group_names = set(state_groups.keys())
        if len(group_names) == 1:
            name, state_list = state_groups.items().pop()
            state = {
                (e.type, e.state_key): e
                for e in state_list
            }
            defer.returnValue((name, state))

            if len(current_branch) == 1:
                # This is a direct clobber so we can just...
                defer.returnValue(True)
        state = {}
        for group, g_state in state_groups.items():
            for s in g_state:
                state.setdefault(
                    (s.type, s.state_key),
                    {}
                )[s.event_id] = s

        unconflicted_state = {
            k: v.values()[0] for k, v in state.items()
            if len(v.values()) == 1
        }

        conflicted_state = {
            k: v.values()
            for k, v in state.items()
            if len(v.values()) > 1
        }

        try:
            new_state = {}
            new_state.update(unconflicted_state)
            for key, events in conflicted_state.items():
                new_state[key] = self._resolve_state_events(events)
        except:
            logger.exception("Failed to resolve state")
            raise

        defer.returnValue((None, new_state))

    def _get_power_level_from_event_state(self, event, user_id):
        if hasattr(event, "old_state_events") and event.old_state_events:
            key = (RoomPowerLevelsEvent.TYPE, "", )
            power_level_event = event.old_state_events.get(key)
            level = None
            if power_level_event:
                level = power_level_event.content.get("users", {}).get(
                    user_id
                )
                if not level:
                    level = power_level_event.content.get("users_default", 0)

            return level
        else:
            # We didn't find a common ancestor. This is probably fine.
            pass
            return 0

        result = yield self._do_conflict_res(
            new_branch, current_branch, common_ancestor
        )
        defer.returnValue(result)
    @log_function
    def _resolve_state_events(self, events):
        curr_events = events

    @defer.inlineCallbacks
    def _do_conflict_res(self, new_branch, current_branch, common_ancestor):
        conflict_res = [
            self._do_power_level_conflict_res,
            self._do_chain_length_conflict_res,
            self._do_hash_conflict_res,
        new_powers = [
            self._get_power_level_from_event_state(e, e.user_id)
            for e in curr_events
        ]

        for algo in conflict_res:
            new_res, curr_res = yield defer.maybeDeferred(
                algo,
                new_branch, current_branch, common_ancestor
            )
        new_powers = [
            int(p) if p else 0 for p in new_powers
        ]

            if new_res < curr_res:
                defer.returnValue(False)
            elif new_res > curr_res:
                defer.returnValue(True)
        max_power = max(new_powers)

        raise Exception("Conflict resolution failed.")
        curr_events = [
            z[0] for z in zip(curr_events, new_powers)
            if z[1] == max_power
        ]

    @defer.inlineCallbacks
    def _do_power_level_conflict_res(self, new_branch, current_branch,
                                     common_ancestor):
        new_powers_deferreds = []
        for e in new_branch[:-1] if common_ancestor else new_branch:
            if hasattr(e, "user_id"):
                new_powers_deferreds.append(
                    self.store.get_power_level(e.context, e.user_id)
                )

        current_powers_deferreds = []
        for e in current_branch[:-1] if common_ancestor else current_branch:
            if hasattr(e, "user_id"):
                current_powers_deferreds.append(
                    self.store.get_power_level(e.context, e.user_id)
                )

        new_powers = yield defer.gatherResults(
            new_powers_deferreds,
            consumeErrors=True
        )

        current_powers = yield defer.gatherResults(
            current_powers_deferreds,
            consumeErrors=True
        )

        max_power_new = max(new_powers)
        max_power_current = max(current_powers)

        defer.returnValue(
            (max_power_new, max_power_current)
        )

    def _do_chain_length_conflict_res(self, new_branch, current_branch,
                                      common_ancestor):
        return (len(new_branch), len(current_branch))

    def _do_hash_conflict_res(self, new_branch, current_branch,
                              common_ancestor):
        new_str = "".join([p.pdu_id + p.origin for p in new_branch])
        c_str = "".join([p.pdu_id + p.origin for p in current_branch])
        if not curr_events:
            raise RuntimeError("Max didn't get a max?")
        elif len(curr_events) == 1:
            return curr_events[0]

        # TODO: For now, just choose the one with the largest event_id.
        return (
            hashlib.sha1(new_str).hexdigest(),
            hashlib.sha1(c_str).hexdigest()
            sorted(
                curr_events,
                key=lambda e: hashlib.sha1(
                    e.event_id + e.user_id + e.room_id + e.type
                ).hexdigest()
            )[0]
        )
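
The new `_resolve_state_events` above boils down to: keep only the conflicting events whose sender had the highest power level, then break any remaining tie deterministically by sorting on a SHA-1 digest of identifying fields. A standalone sketch of that selection (Python 2 era, like the surrounding code; `get_level` stands in for the power-level lookup, and the event objects are assumed to expose event_id, user_id, room_id and type as strings):

    import hashlib

    def resolve(events, get_level):
        # Keep the events sent at the maximum power level.
        levels = [int(get_level(e) or 0) for e in events]
        top = max(levels)
        remaining = [e for e, lvl in zip(events, levels) if lvl == top]
        if len(remaining) == 1:
            return remaining[0]
        # Deterministic tiebreak: hash identifying fields and take the first.
        return sorted(
            remaining,
            key=lambda e: hashlib.sha1(
                e.event_id + e.user_id + e.room_id + e.type
            ).hexdigest()
        )[0]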
@ -16,14 +16,7 @@
from twisted.internet import defer

from synapse.api.events.room import (
    RoomMemberEvent, RoomTopicEvent, FeedbackEvent,
    # RoomConfigEvent,
    RoomNameEvent,
    RoomJoinRulesEvent,
    RoomPowerLevelsEvent,
    RoomAddStateLevelEvent,
    RoomSendEventLevelEvent,
    RoomOpsPowerLevelsEvent,
    RoomMemberEvent, RoomTopicEvent, FeedbackEvent, RoomNameEvent,
    RoomRedactionEvent,
)

@ -37,9 +30,17 @@ from .registration import RegistrationStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .stream import StreamStore
from .pdu import StatePduStore, PduStore, PdusTable
from .transactions import TransactionStore
from .keys import KeyStore
from .event_federation import EventFederationStore

from .state import StateStore
from .signatures import SignatureStore

from syutil.base64util import decode_base64

from synapse.crypto.event_signing import compute_event_reference_hash


import json
import logging
@ -51,7 +52,6 @@ logger = logging.getLogger(__name__)

SCHEMAS = [
    "transactions",
    "pdu",
    "users",
    "profiles",
    "presence",
@ -59,6 +59,9 @@ SCHEMAS = [
    "room_aliases",
    "keys",
    "redactions",
    "state",
    "event_edges",
    "event_signatures",
]


@ -73,10 +76,12 @@ class _RollbackButIsFineException(Exception):
    """
    pass


class DataStore(RoomMemberStore, RoomStore,
                RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
                PresenceStore, PduStore, StatePduStore, TransactionStore,
                DirectoryStore, KeyStore):
                PresenceStore, TransactionStore,
                DirectoryStore, KeyStore, StateStore, SignatureStore,
                EventFederationStore, ):

    def __init__(self, hs):
        super(DataStore, self).__init__(hs)
@ -88,8 +93,7 @@ class DataStore(RoomMemberStore, RoomStore,

    @defer.inlineCallbacks
    @log_function
    def persist_event(self, event=None, backfilled=False, pdu=None,
                      is_new_state=True):
    def persist_event(self, event, backfilled=False, is_new_state=True):
        stream_ordering = None
        if backfilled:
            if not self.min_token_deferred.called:
@ -99,8 +103,8 @@ class DataStore(RoomMemberStore, RoomStore,

        try:
            yield self.runInteraction(
                self._persist_pdu_event_txn,
                pdu=pdu,
                "persist_event",
                self._persist_event_txn,
                event=event,
                backfilled=backfilled,
                stream_ordering=stream_ordering,
@ -119,7 +123,8 @@ class DataStore(RoomMemberStore, RoomStore,
                "type",
                "room_id",
                "content",
                "unrecognized_keys"
                "unrecognized_keys",
                "depth",
            ],
            allow_none=allow_none,
        )
@ -130,42 +135,6 @@ class DataStore(RoomMemberStore, RoomStore,
        event = self._parse_event_from_row(events_dict)
        defer.returnValue(event)

    def _persist_pdu_event_txn(self, txn, pdu=None, event=None,
                               backfilled=False, stream_ordering=None,
                               is_new_state=True):
        if pdu is not None:
            self._persist_event_pdu_txn(txn, pdu)
        if event is not None:
            return self._persist_event_txn(
                txn, event, backfilled, stream_ordering,
                is_new_state=is_new_state,
            )

    def _persist_event_pdu_txn(self, txn, pdu):
        cols = dict(pdu.__dict__)
        unrec_keys = dict(pdu.unrecognized_keys)
        del cols["content"]
        del cols["prev_pdus"]
        cols["content_json"] = json.dumps(pdu.content)

        unrec_keys.update({
            k: v for k, v in cols.items()
            if k not in PdusTable.fields
        })

        cols["unrecognized_keys"] = json.dumps(unrec_keys)

        cols["ts"] = cols.pop("origin_server_ts")

        logger.debug("Persisting: %s", repr(cols))

        if pdu.is_state:
            self._persist_state_txn(txn, pdu.prev_pdus, cols)
        else:
            self._persist_pdu_txn(txn, pdu.prev_pdus, cols)

        self._update_min_depth_for_context_txn(txn, pdu.context, pdu.depth)

    @log_function
    def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None,
                           is_new_state=True):
@ -177,19 +146,13 @@ class DataStore(RoomMemberStore, RoomStore,
            self._store_room_name_txn(txn, event)
        elif event.type == RoomTopicEvent.TYPE:
            self._store_room_topic_txn(txn, event)
        elif event.type == RoomJoinRulesEvent.TYPE:
            self._store_join_rule(txn, event)
        elif event.type == RoomPowerLevelsEvent.TYPE:
            self._store_power_levels(txn, event)
        elif event.type == RoomAddStateLevelEvent.TYPE:
            self._store_add_state_level(txn, event)
        elif event.type == RoomSendEventLevelEvent.TYPE:
            self._store_send_event_level(txn, event)
        elif event.type == RoomOpsPowerLevelsEvent.TYPE:
            self._store_ops_level(txn, event)
        elif event.type == RoomRedactionEvent.TYPE:
            self._store_redaction(txn, event)

        outlier = False
        if hasattr(event, "outlier"):
            outlier = event.outlier

        vals = {
            "topological_ordering": event.depth,
            "event_id": event.event_id,
@ -197,25 +160,33 @@ class DataStore(RoomMemberStore, RoomStore,
            "room_id": event.room_id,
            "content": json.dumps(event.content),
            "processed": True,
            "outlier": outlier,
            "depth": event.depth,
        }

        if stream_ordering is not None:
            vals["stream_ordering"] = stream_ordering

        if hasattr(event, "outlier"):
            vals["outlier"] = event.outlier
        else:
            vals["outlier"] = False

        unrec = {
            k: v
            for k, v in event.get_full_dict().items()
            if k not in vals.keys() and k not in ["redacted", "redacted_because"]
            if k not in vals.keys() and k not in [
                "redacted",
                "redacted_because",
                "signatures",
                "hashes",
                "prev_events",
            ]
        }
        vals["unrecognized_keys"] = json.dumps(unrec)

        try:
            self._simple_insert_txn(txn, "events", vals)
            self._simple_insert_txn(
                txn,
                "events",
                vals,
                or_replace=(not outlier),
            )
        except:
            logger.warn(
                "Failed to persist, probably duplicate: %s",
@ -224,6 +195,16 @@ class DataStore(RoomMemberStore, RoomStore,
            )
            raise _RollbackButIsFineException("_persist_event")

        self._handle_prev_events(
            txn,
            outlier=outlier,
            event_id=event.event_id,
            prev_events=event.prev_events,
            room_id=event.room_id,
        )

        self._store_state_groups_txn(txn, event)

        is_state = hasattr(event, "state_key") and event.state_key is not None
        if is_new_state and is_state:
            vals = {
@ -233,8 +214,8 @@ class DataStore(RoomMemberStore, RoomStore,
                "state_key": event.state_key,
            }

            if hasattr(event, "prev_state"):
                vals["prev_state"] = event.prev_state
            if hasattr(event, "replaces_state"):
                vals["prev_state"] = event.replaces_state

            self._simple_insert_txn(txn, "state_events", vals)

@ -249,6 +230,81 @@ class DataStore(RoomMemberStore, RoomStore,
                }
            )

            for e_id, h in event.prev_state:
                self._simple_insert_txn(
                    txn,
                    table="event_edges",
                    values={
                        "event_id": event.event_id,
                        "prev_event_id": e_id,
                        "room_id": event.room_id,
                        "is_state": 1,
                    },
                    or_ignore=True,
                )

        if not backfilled:
            self._simple_insert_txn(
                txn,
                table="state_forward_extremities",
                values={
                    "event_id": event.event_id,
                    "room_id": event.room_id,
                    "type": event.type,
                    "state_key": event.state_key,
                }
            )

            for prev_state_id, _ in event.prev_state:
                self._simple_delete_txn(
                    txn,
                    table="state_forward_extremities",
                    keyvalues={
                        "event_id": prev_state_id,
                    }
                )

        for hash_alg, hash_base64 in event.hashes.items():
            hash_bytes = decode_base64(hash_base64)
            self._store_event_content_hash_txn(
                txn, event.event_id, hash_alg, hash_bytes,
            )

        if hasattr(event, "signatures"):
            signatures = event.signatures.get(event.origin, {})

            for key_id, signature_base64 in signatures.items():
                signature_bytes = decode_base64(signature_base64)
                self._store_event_origin_signature_txn(
                    txn, event.event_id, event.origin, key_id, signature_bytes,
                )

        for prev_event_id, prev_hashes in event.prev_events:
            for alg, hash_base64 in prev_hashes.items():
                hash_bytes = decode_base64(hash_base64)
                self._store_prev_event_hash_txn(
                    txn, event.event_id, prev_event_id, alg, hash_bytes
                )

        for auth_id, _ in event.auth_events:
            self._simple_insert_txn(
                txn,
                table="event_auth",
                values={
                    "event_id": event.event_id,
                    "room_id": event.room_id,
                    "auth_id": auth_id,
                },
                or_ignore=True,
            )

        (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
        self._store_event_reference_hash_txn(
            txn, event.event_id, ref_alg, ref_hash_bytes
        )

        self._update_min_depth_for_room_txn(txn, event.room_id, event.depth)
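
The persistence block above stores hashes and signatures as raw bytes, while events carry them base64-encoded with the padding stripped, so writing is a decode and reading back is an encode. A sketch of that round trip with the same syutil helpers (the hash value is illustrative, and the final equality assumes the input was already in the canonical unpadded form):

    from syutil.base64util import decode_base64, encode_base64

    hash_b64 = "A" * 43                    # illustrative 32-byte hash, base64 without padding
    hash_bytes = decode_base64(hash_b64)   # raw bytes, as stored in the DB
    assert encode_base64(hash_bytes) == hash_b64   # wire form, as carried on the event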
    def _store_redaction(self, txn, event):
        txn.execute(
            "INSERT OR IGNORE INTO redactions "
@ -319,7 +375,7 @@ class DataStore(RoomMemberStore, RoomStore,
            ],
        )

    def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
    def snapshot_room(self, event):
        """Snapshot the room for an update by a user
        Args:
            room_id (synapse.types.RoomId): The room to snapshot.
@ -330,29 +386,33 @@ class DataStore(RoomMemberStore, RoomStore,
            synapse.storage.Snapshot: A snapshot of the state of the room.
        """
        def _snapshot(txn):
            membership_state = self._get_room_member(txn, user_id, room_id)
            prev_pdus = self._get_latest_pdus_in_context(
                txn, room_id
            prev_events = self._get_latest_events_in_room(
                txn,
                event.room_id
            )
            if state_type is not None and state_key is not None:
                prev_state_pdu = self._get_current_state_pdu(
                    txn, room_id, state_type, state_key

            prev_state = None
            state_key = None
            if hasattr(event, "state_key"):
                state_key = event.state_key
                prev_state = self._get_latest_state_in_room(
                    txn,
                    event.room_id,
                    type=event.type,
                    state_key=state_key,
                )
            else:
                prev_state_pdu = None

            return Snapshot(
                store=self,
                room_id=room_id,
                user_id=user_id,
                prev_pdus=prev_pdus,
                membership_state=membership_state,
                state_type=state_type,
                room_id=event.room_id,
                user_id=event.user_id,
                prev_events=prev_events,
                prev_state=prev_state,
                state_type=event.type,
                state_key=state_key,
                prev_state_pdu=prev_state_pdu,
            )

        return self.runInteraction(_snapshot)
        return self.runInteraction("snapshot_room", _snapshot)


class Snapshot(object):
@ -361,7 +421,7 @@ class Snapshot(object):
        store (DataStore): The datastore.
        room_id (RoomId): The room of the snapshot.
        user_id (UserId): The user this snapshot is for.
        prev_pdus (list): The list of PDU ids this snapshot is after.
        prev_events (list): The list of event ids this snapshot is after.
        membership_state (RoomMemberEvent): The current state of the user in
            the room.
        state_type (str, optional): State type captured by the snapshot
@ -370,32 +430,30 @@ class Snapshot(object):
            the previous value of the state type and key in the room.
    """

    def __init__(self, store, room_id, user_id, prev_pdus,
                 membership_state, state_type=None, state_key=None,
                 prev_state_pdu=None):
    def __init__(self, store, room_id, user_id, prev_events,
                 prev_state, state_type=None, state_key=None):
        self.store = store
        self.room_id = room_id
        self.user_id = user_id
        self.prev_pdus = prev_pdus
        self.membership_state = membership_state
        self.prev_events = prev_events
        self.prev_state = prev_state
        self.state_type = state_type
        self.state_key = state_key
        self.prev_state_pdu = prev_state_pdu

    def fill_out_prev_events(self, event):
        if hasattr(event, "prev_events"):
            return
        if not hasattr(event, "prev_events"):
            event.prev_events = [
                (event_id, hashes)
                for event_id, hashes, _ in self.prev_events
            ]

        es = [
            "%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus
        ]
            if self.prev_events:
                event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
            else:
                event.depth = 0

        event.prev_events = [e for e in es if e != event.event_id]

        if self.prev_pdus:
            event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1
        else:
            event.depth = 0
        if not hasattr(event, "prev_state") and self.prev_state is not None:
            event.prev_state = self.prev_state
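
The depth rule in `fill_out_prev_events` above is simply max-of-parents plus one, falling back to 0 for an event with no parents. A worked example, assuming three (event_id, hashes, depth) tuples with made-up values:

    prev_events = [("$a", {}, 2), ("$b", {}, 5), ("$c", {}, 3)]
    if prev_events:
        depth = max(int(v) for _, _, v in prev_events) + 1   # 5 + 1 == 6
    else:
        depth = 0   # first event in a brand-new room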

def schema_path(schema):
@ -436,11 +494,13 @@ def prepare_database(db_conn):
        user_version = row[0]

        if user_version > SCHEMA_VERSION:
            raise ValueError("Cannot use this database as it is too " +
            raise ValueError(
                "Cannot use this database as it is too " +
                "new for the server to understand"
            )
        elif user_version < SCHEMA_VERSION:
            logging.info("Upgrading database from version %d",
            logging.info(
                "Upgrading database from version %d",
                user_version
            )

@ -452,13 +512,13 @@ def prepare_database(db_conn):
            db_conn.commit()

    else:
        sql_script = "BEGIN TRANSACTION;"
        sql_script = "BEGIN TRANSACTION;\n"
        for sql_loc in SCHEMAS:
            sql_script += read_schema(sql_loc)
            sql_script += "\n"
        sql_script += "COMMIT TRANSACTION;"
        c.executescript(sql_script)
        db_conn.commit()
        c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)

    c.close()
@ -14,59 +14,69 @@
# limitations under the License.
import logging

from twisted.internet import defer

from synapse.api.errors import StoreError
from synapse.api.events.utils import prune_event
from synapse.util.logutils import log_function
from syutil.base64util import encode_base64

import collections
import copy
import json
import sys
import time


logger = logging.getLogger(__name__)

sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")


class LoggingTransaction(object):
    """An object that almost-transparently proxies for the 'txn' object
    passed to the constructor. Adds logging to the .execute() method."""
    __slots__ = ["txn"]
    __slots__ = ["txn", "name"]

    def __init__(self, txn):
    def __init__(self, txn, name):
        object.__setattr__(self, "txn", txn)
        object.__setattr__(self, "name", name)

    def __getattribute__(self, name):
        if name == "execute":
            return object.__getattribute__(self, "execute")

        return getattr(object.__getattribute__(self, "txn"), name)
    def __getattr__(self, name):
        return getattr(self.txn, name)

    def __setattr__(self, name, value):
        setattr(object.__getattribute__(self, "txn"), name, value)
        setattr(self.txn, name, value)

    def execute(self, sql, *args, **kwargs):
        # TODO(paul): Maybe use 'info' and 'debug' for values?
        sql_logger.debug("[SQL] %s", sql)
        sql_logger.debug("[SQL] {%s} %s", self.name, sql)
        try:
            if args and args[0]:
                values = args[0]
                sql_logger.debug("[SQL values] " +
                                 ", ".join(("<%s>",) * len(values)), *values)
                sql_logger.debug(
                    "[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)),
                    self.name,
                    *values
                )
        except:
            # Don't let logging failures stop SQL from working
            pass

        # TODO(paul): Here would be an excellent place to put some timing
        # measurements, and log (warning?) slow queries.
        return object.__getattribute__(self, "txn").execute(
            sql, *args, **kwargs
        )
        start = time.clock() * 1000
        try:
            return self.txn.execute(
                sql, *args, **kwargs
            )
        except:
            logger.exception("[SQL FAIL] {%s}", self.name)
            raise
        finally:
            end = time.clock() * 1000
            sql_logger.debug("[SQL time] {%s} %f", self.name, end - start)


class SQLBaseStore(object):
    _TXN_ID = 0

    def __init__(self, hs):
        self.hs = hs
@ -74,10 +84,30 @@ class SQLBaseStore(object):
        self.event_factory = hs.get_event_factory()
        self._clock = hs.get_clock()

    def runInteraction(self, func, *args, **kwargs):
    def runInteraction(self, desc, func, *args, **kwargs):
        """Wraps the .runInteraction() method on the underlying db_pool."""
        def inner_func(txn, *args, **kwargs):
            return func(LoggingTransaction(txn), *args, **kwargs)
            start = time.clock() * 1000
            txn_id = SQLBaseStore._TXN_ID

            # We don't really need these to be unique, so lets stop it from
            # growing really large.
            self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)

            name = "%s-%x" % (desc, txn_id, )

            transaction_logger.debug("[TXN START] {%s}", name)
            try:
                return func(LoggingTransaction(txn, name), *args, **kwargs)
            except:
                logger.exception("[TXN FAIL] {%s}", name)
                raise
            finally:
                end = time.clock() * 1000
                transaction_logger.debug(
                    "[TXN END] {%s} %f",
                    name, end - start
                )

        return self._db_pool.runInteraction(inner_func, *args, **kwargs)
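
Call sites now pass a human-readable description as the first argument, which runInteraction turns into a per-transaction name such as "persist_event-1a" for the [TXN START]/[TXN END] and [SQL] log lines. A sketch of a store method written against the new signature (the table and method are illustrative, not part of this commit):

    def get_widget_name(self, widget_id):
        def func(txn):
            txn.execute("SELECT name FROM widgets WHERE id = ?", (widget_id,))
            row = txn.fetchone()
            return row[0] if row else None
        return self.runInteraction("get_widget_name", func)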
@ -113,7 +143,7 @@ class SQLBaseStore(object):
            else:
                return cursor.fetchall()

        return self.runInteraction(interaction)
        return self.runInteraction("_execute", interaction)

    def _execute_and_decode(self, query, *args):
        return self._execute(self.cursor_to_dict, query, *args)
@ -130,6 +160,7 @@ class SQLBaseStore(object):
            or_replace : bool; if True performs an INSERT OR REPLACE
        """
        return self.runInteraction(
            "_simple_insert",
            self._simple_insert_txn, table, values, or_replace=or_replace,
            or_ignore=or_ignore,
        )
@ -170,7 +201,6 @@ class SQLBaseStore(object):
            table, keyvalues, retcols=retcols, allow_none=allow_none
        )

    @defer.inlineCallbacks
    def _simple_select_one_onecol(self, table, keyvalues, retcol,
                                  allow_none=False):
        """Executes a SELECT query on the named table, which is expected to
@ -181,19 +211,40 @@ class SQLBaseStore(object):
            keyvalues : dict of column names and values to select the row with
            retcol : string giving the name of the column to return
        """
        ret = yield self._simple_select_one(
        return self.runInteraction(
            "_simple_select_one_onecol",
            self._simple_select_one_onecol_txn,
            table, keyvalues, retcol, allow_none=allow_none,
        )

    def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
                                      allow_none=False):
        ret = self._simple_select_onecol_txn(
            txn,
            table=table,
            keyvalues=keyvalues,
            retcols=[retcol],
            allow_none=allow_none
            retcol=retcol,
        )

        if ret:
            defer.returnValue(ret[retcol])
            return ret[0]
        else:
            defer.returnValue(None)
            if allow_none:
                return None
            else:
                raise StoreError(404, "No row found")

    def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
        sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
            "retcol": retcol,
            "table": table,
            "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
        }

        txn.execute(sql, keyvalues.values())

        return [r[0] for r in txn.fetchall()]
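
As a worked example of the SQL the new helper builds: with table="profiles", keyvalues={"user_id": "@bob:example.com"} and retcol="displayname" (illustrative values), the interpolation yields

    SELECT displayname FROM profiles WHERE user_id = ?

executed with ("@bob:example.com",) as the bound parameter; the helper then returns the first column of every matching row as a plain list.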
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _simple_select_onecol(self, table, keyvalues, retcol):
|
||||
"""Executes a SELECT query on the named table, which returns a list
|
||||
comprising of the values of the named column from the selected rows.
|
||||
@ -206,19 +257,11 @@ class SQLBaseStore(object):
|
||||
Returns:
|
||||
Deferred: Results in a list
|
||||
"""
|
||||
sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
|
||||
"retcol": retcol,
|
||||
"table": table,
|
||||
"where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
|
||||
}
|
||||
|
||||
def func(txn):
|
||||
txn.execute(sql, keyvalues.values())
|
||||
return txn.fetchall()
|
||||
|
||||
res = yield self.runInteraction(func)
|
||||
|
||||
defer.returnValue([r[0] for r in res])
|
||||
return self.runInteraction(
|
||||
"_simple_select_onecol",
|
||||
self._simple_select_onecol_txn,
|
||||
table, keyvalues, retcol
|
||||
)
|
||||
|
||||
def _simple_select_list(self, table, keyvalues, retcols):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
@ -229,17 +272,30 @@ class SQLBaseStore(object):
|
||||
keyvalues : dict of column names and values to select the rows with
|
||||
retcols : list of strings giving the names of the columns to return
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"_simple_select_list",
|
||||
self._simple_select_list_txn,
|
||||
table, keyvalues, retcols
|
||||
)
|
||||
|
||||
def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
Args:
|
||||
txn : Transaction object
|
||||
table : string giving the table name
|
||||
keyvalues : dict of column names and values to select the rows with
|
||||
retcols : list of strings giving the names of the columns to return
|
||||
"""
|
||||
sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k) for k in keyvalues)
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
)
|
||||
|
||||
def func(txn):
|
||||
txn.execute(sql, keyvalues.values())
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
return self.runInteraction(func)
|
||||
txn.execute(sql, keyvalues.values())
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||
retcols=None):
|
||||
@ -307,7 +363,7 @@ class SQLBaseStore(object):
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
return ret
|
||||
return self.runInteraction(func)
|
||||
return self.runInteraction("_simple_selectupdate_one", func)
|
||||
|
||||
def _simple_delete_one(self, table, keyvalues):
|
||||
"""Executes a DELETE query on the named table, expecting to delete a
|
||||
@ -319,7 +375,7 @@ class SQLBaseStore(object):
|
||||
"""
|
||||
sql = "DELETE FROM %s WHERE %s" % (
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k) for k in keyvalues)
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
)
|
||||
|
||||
def func(txn):
|
||||
@ -328,7 +384,25 @@ class SQLBaseStore(object):
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "more than one row matched")
|
||||
return self.runInteraction(func)
|
||||
return self.runInteraction("_simple_delete_one", func)
|
||||
|
||||
def _simple_delete(self, table, keyvalues):
|
||||
"""Executes a DELETE query on the named table.
|
||||
|
||||
Args:
|
||||
table : string giving the table name
|
||||
keyvalues : dict of column names and values to select the row with
|
||||
"""
|
||||
|
||||
return self.runInteraction("_simple_delete", self._simple_delete_txn)
|
||||
|
||||
def _simple_delete_txn(self, txn, table, keyvalues):
|
||||
sql = "DELETE FROM %s WHERE %s" % (
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
)
|
||||
|
||||
return txn.execute(sql, keyvalues.values())
|
||||
|
||||
def _simple_max_id(self, table):
|
||||
"""Executes a SELECT query on the named table, expecting to return the
|
||||
@ -346,7 +420,7 @@ class SQLBaseStore(object):
|
||||
return 0
|
||||
return max_id
|
||||
|
||||
return self.runInteraction(func)
|
||||
return self.runInteraction("_simple_max_id", func)
|
||||
|
||||
def _parse_event_from_row(self, row_dict):
|
||||
d = copy.deepcopy({k: v for k, v in row_dict.items()})
|
||||
@ -355,6 +429,10 @@ class SQLBaseStore(object):
|
||||
d.pop("topological_ordering", None)
|
||||
d.pop("processed", None)
|
||||
d["origin_server_ts"] = d.pop("ts", 0)
|
||||
replaces_state = d.pop("prev_state", None)
|
||||
|
||||
if replaces_state:
|
||||
d["replaces_state"] = replaces_state
|
||||
|
||||
d.update(json.loads(row_dict["unrecognized_keys"]))
|
||||
d["content"] = json.loads(d["content"])
|
||||
@ -369,23 +447,65 @@ class SQLBaseStore(object):
|
||||
**d
|
||||
)
|
||||
|
||||
def _get_events_txn(self, txn, event_ids):
|
||||
# FIXME (erikj): This should be batched?
|
||||
|
||||
sql = "SELECT * FROM events WHERE event_id = ?"
|
||||
|
||||
event_rows = []
|
||||
for e_id in event_ids:
|
||||
c = txn.execute(sql, (e_id,))
|
||||
event_rows.extend(self.cursor_to_dict(c))
|
||||
|
||||
return self._parse_events_txn(txn, event_rows)
|
||||
|
||||
def _parse_events(self, rows):
|
||||
return self.runInteraction(self._parse_events_txn, rows)
|
||||
return self.runInteraction(
|
||||
"_parse_events", self._parse_events_txn, rows
|
||||
)
|
||||
|
||||
def _parse_events_txn(self, txn, rows):
|
||||
events = [self._parse_event_from_row(r) for r in rows]
|
||||
|
||||
sql = "SELECT * FROM events WHERE event_id = ?"
|
||||
select_event_sql = "SELECT * FROM events WHERE event_id = ?"
|
||||
|
||||
for ev in events:
|
||||
if hasattr(ev, "prev_state"):
|
||||
# Load previous state_content.
|
||||
# TODO: Should we be pulling this out above?
|
||||
cursor = txn.execute(sql, (ev.prev_state,))
|
||||
prevs = self.cursor_to_dict(cursor)
|
||||
if prevs:
|
||||
prev = self._parse_event_from_row(prevs[0])
|
||||
ev.prev_content = prev.content
|
||||
for i, ev in enumerate(events):
|
||||
signatures = self._get_event_origin_signatures_txn(
|
||||
txn, ev.event_id,
|
||||
)
|
||||
|
||||
ev.signatures = {
|
||||
k: encode_base64(v) for k, v in signatures.items()
|
||||
}
|
||||
|
||||
prevs = self._get_prev_events_and_state(txn, ev.event_id)
|
||||
|
||||
ev.prev_events = [
|
||||
(e_id, h)
|
||||
for e_id, h, is_state in prevs
|
||||
if is_state == 0
|
||||
]
|
||||
|
||||
ev.auth_events = self._get_auth_events(txn, ev.event_id)
|
||||
|
||||
if hasattr(ev, "state_key"):
|
||||
ev.prev_state = [
|
||||
(e_id, h)
|
||||
for e_id, h, is_state in prevs
|
||||
if is_state == 1
|
||||
]
|
||||
|
||||
if hasattr(ev, "replaces_state"):
|
||||
# Load previous state_content.
|
||||
# FIXME (erikj): Handle multiple prev_states.
|
||||
cursor = txn.execute(
|
||||
select_event_sql,
|
||||
(ev.replaces_state,)
|
||||
)
|
||||
prevs = self.cursor_to_dict(cursor)
|
||||
if prevs:
|
||||
prev = self._parse_event_from_row(prevs[0])
|
||||
ev.prev_content = prev.content
|
||||
|
||||
if not hasattr(ev, "redacted"):
|
||||
logger.debug("Doesn't have redacted key: %s", ev)
|
||||
@ -393,15 +513,16 @@ class SQLBaseStore(object):
|
||||
|
||||
if ev.redacted:
|
||||
# Get the redaction event.
|
||||
sql = "SELECT * FROM events WHERE event_id = ?"
|
||||
txn.execute(sql, (ev.redacted,))
|
||||
select_event_sql = "SELECT * FROM events WHERE event_id = ?"
|
||||
txn.execute(select_event_sql, (ev.redacted,))
|
||||
|
||||
del_evs = self._parse_events_txn(
|
||||
txn, self.cursor_to_dict(txn)
|
||||
)
|
||||
|
||||
if del_evs:
|
||||
prune_event(ev)
|
||||
ev = prune_event(ev)
|
||||
events[i] = ev
|
||||
ev.redacted_because = del_evs[0]
|
||||
|
||||
return events
|
||||
|
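A note on the redaction change above: replacing `prune_event(ev)` with `ev = prune_event(ev)` followed by `events[i] = ev` suggests that `prune_event` returns a pruned copy rather than mutating its argument, so the result has to be written back into the list. A minimal sketch of the pitfall, with a hypothetical stand-in for `prune_event`:

    def prune_event(ev):
        # hypothetical stand-in: returns a new, stripped-down copy
        pruned = dict(ev)
        pruned["content"] = {}
        return pruned

    events = [{"event_id": "$a:example.com", "content": {"body": "secret"}}]
    for i, ev in enumerate(events):
        ev = prune_event(ev)
        events[i] = ev  # without this write-back the list keeps the original

    assert events[0]["content"] == {}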
@ -95,6 +95,7 @@ class DirectoryStore(SQLBaseStore):

    def delete_room_alias(self, room_alias):
        return self.runInteraction(
            "delete_room_alias",
            self._delete_room_alias_txn,
            room_alias,
        )

377 synapse/storage/event_federation.py Normal file
@ -0,0 +1,377 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ._base import SQLBaseStore
from syutil.base64util import encode_base64

import logging


logger = logging.getLogger(__name__)


class EventFederationStore(SQLBaseStore):

    def get_auth_chain(self, event_id):
        return self.runInteraction(
            "get_auth_chain",
            self._get_auth_chain_txn,
            event_id
        )

    def _get_auth_chain_txn(self, txn, event_id):
        results = self._get_auth_chain_ids_txn(txn, event_id)

        sql = "SELECT * FROM events WHERE event_id = ?"
        rows = []
        for ev_id in results:
            c = txn.execute(sql, (ev_id,))
            rows.extend(self.cursor_to_dict(c))

        return self._parse_events_txn(txn, rows)

    def get_auth_chain_ids(self, event_id):
        return self.runInteraction(
            "get_auth_chain_ids",
            self._get_auth_chain_ids_txn,
            event_id
        )

    def _get_auth_chain_ids_txn(self, txn, event_id):
        results = set()

        base_sql = (
            "SELECT auth_id FROM event_auth WHERE %s"
        )

        front = set([event_id])
        while front:
            sql = base_sql % (
                " OR ".join(["event_id=?"] * len(front)),
            )

            txn.execute(sql, list(front))
            front = [r[0] for r in txn.fetchall()]
            results.update(front)

        return list(results)
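The auth-chain walk above expands a frontier of event ids one batch per query until no new auth events turn up. A stand-alone sketch of the same traversal against a plain sqlite3 connection (the defensive `- results` set subtraction is an addition here, guarding against cycles in the auth graph):

    import sqlite3

    def get_auth_chain_ids(conn, event_id):
        # Breadth-first walk over event_auth: each iteration fetches the
        # auth events of the entire current frontier in a single query.
        results = set()
        front = {event_id}
        while front:
            clause = " OR ".join(["event_id = ?"] * len(front))
            rows = conn.execute(
                "SELECT auth_id FROM event_auth WHERE " + clause, list(front)
            ).fetchall()
            front = {r[0] for r in rows} - results
            results.update(front)
        return list(results)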

    def get_oldest_events_in_room(self, room_id):
        return self.runInteraction(
            "get_oldest_events_in_room",
            self._get_oldest_events_in_room_txn,
            room_id,
        )

    def _get_oldest_events_in_room_txn(self, txn, room_id):
        return self._simple_select_onecol_txn(
            txn,
            table="event_backward_extremities",
            keyvalues={
                "room_id": room_id,
            },
            retcol="event_id",
        )

    def get_latest_events_in_room(self, room_id):
        return self.runInteraction(
            "get_latest_events_in_room",
            self._get_latest_events_in_room,
            room_id,
        )

    def _get_latest_events_in_room(self, txn, room_id):
        sql = (
            "SELECT e.event_id, e.depth FROM events as e "
            "INNER JOIN event_forward_extremities as f "
            "ON e.event_id = f.event_id "
            "WHERE f.room_id = ?"
        )

        txn.execute(sql, (room_id, ))

        results = []
        for event_id, depth in txn.fetchall():
            hashes = self._get_event_reference_hashes_txn(txn, event_id)
            prev_hashes = {
                k: encode_base64(v) for k, v in hashes.items()
                if k == "sha256"
            }
            results.append((event_id, prev_hashes, depth))

        return results

    def _get_latest_state_in_room(self, txn, room_id, type, state_key):
        event_ids = self._simple_select_onecol_txn(
            txn,
            table="state_forward_extremities",
            keyvalues={
                "room_id": room_id,
                "type": type,
                "state_key": state_key,
            },
            retcol="event_id",
        )

        results = []
        for event_id in event_ids:
            hashes = self._get_event_reference_hashes_txn(txn, event_id)
            prev_hashes = {
                k: encode_base64(v) for k, v in hashes.items()
                if k == "sha256"
            }
            results.append((event_id, prev_hashes))

        return results

    def _get_prev_events(self, txn, event_id):
        results = self._get_prev_events_and_state(
            txn,
            event_id,
            is_state=0,
        )

        return [(e_id, h, ) for e_id, h, _ in results]

    def _get_prev_state(self, txn, event_id):
        results = self._get_prev_events_and_state(
            txn,
            event_id,
            is_state=1,
        )

        return [(e_id, h, ) for e_id, h, _ in results]

    def _get_prev_events_and_state(self, txn, event_id, is_state=None):
        keyvalues = {
            "event_id": event_id,
        }

        if is_state is not None:
            keyvalues["is_state"] = is_state

        res = self._simple_select_list_txn(
            txn,
            table="event_edges",
            keyvalues=keyvalues,
            retcols=["prev_event_id", "is_state"],
        )

        results = []
        for d in res:
            hashes = self._get_event_reference_hashes_txn(
                txn,
                d["prev_event_id"]
            )
            prev_hashes = {
                k: encode_base64(v) for k, v in hashes.items()
                if k == "sha256"
            }
            results.append((d["prev_event_id"], prev_hashes, d["is_state"]))

        return results

    def _get_auth_events(self, txn, event_id):
        auth_ids = self._simple_select_onecol_txn(
            txn,
            table="event_auth",
            keyvalues={
                "event_id": event_id,
            },
            retcol="auth_id",
        )

        results = []
        for auth_id in auth_ids:
            hashes = self._get_event_reference_hashes_txn(txn, auth_id)
            prev_hashes = {
                k: encode_base64(v) for k, v in hashes.items()
                if k == "sha256"
            }
            results.append((auth_id, prev_hashes))

        return results

    def get_min_depth(self, room_id):
        return self.runInteraction(
            "get_min_depth",
            self._get_min_depth_interaction,
            room_id,
        )

    def _get_min_depth_interaction(self, txn, room_id):
        min_depth = self._simple_select_one_onecol_txn(
            txn,
            table="room_depth",
            keyvalues={"room_id": room_id},
            retcol="min_depth",
            allow_none=True,
        )

        return int(min_depth) if min_depth is not None else None

    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
        min_depth = self._get_min_depth_interaction(txn, room_id)

        do_insert = depth < min_depth if min_depth else True

        if do_insert:
            self._simple_insert_txn(
                txn,
                table="room_depth",
                values={
                    "room_id": room_id,
                    "min_depth": depth,
                },
                or_replace=True,
            )
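`_update_min_depth_for_room_txn` only ever lowers a room's recorded minimum depth (or seeds it when no row exists yet); the `or_replace=True` insert is keyed on `room_id`, so there is one row per room. The same rule as a tiny worked example:

    def new_min_depth(current_min, incoming_depth):
        # No row yet: seed with the incoming depth; otherwise keep the lower.
        if current_min is None:
            return incoming_depth
        return min(current_min, incoming_depth)

    assert new_min_depth(None, 10) == 10
    assert new_min_depth(10, 3) == 3
    assert new_min_depth(3, 10) == 3  # an update never raises the minimum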
    def _handle_prev_events(self, txn, outlier, event_id, prev_events,
                            room_id):
        for e_id, _ in prev_events:
            # TODO (erikj): This could be done as a bulk insert
            self._simple_insert_txn(
                txn,
                table="event_edges",
                values={
                    "event_id": event_id,
                    "prev_event_id": e_id,
                    "room_id": room_id,
                    "is_state": 0,
                },
                or_ignore=True,
            )

        # Update the extremities table if this is not an outlier.
        if not outlier:
            for e_id, _ in prev_events:
                # TODO (erikj): This could be done as a bulk insert
                self._simple_delete_txn(
                    txn,
                    table="event_forward_extremities",
                    keyvalues={
                        "event_id": e_id,
                        "room_id": room_id,
                    }
                )

            # We only insert as a forward extremity the new pdu if there are
            # no other pdus that reference it as a prev pdu
            query = (
                "INSERT OR IGNORE INTO %(table)s (event_id, room_id) "
                "SELECT ?, ? WHERE NOT EXISTS ("
                "SELECT 1 FROM %(event_edges)s WHERE "
                "prev_event_id = ? "
                ")"
            ) % {
                "table": "event_forward_extremities",
                "event_edges": "event_edges",
            }

            logger.debug("query: %s", query)

            txn.execute(query, (event_id, room_id, event_id))

            # Insert all the prev_pdus as a backwards thing, they'll get
            # deleted in a second if they're incorrect anyway.
            for e_id, _ in prev_events:
                # TODO (erikj): This could be done as a bulk insert
                self._simple_insert_txn(
                    txn,
                    table="event_backward_extremities",
                    values={
                        "event_id": e_id,
                        "room_id": room_id,
                    },
                    or_ignore=True,
                )

            # Also delete from the backwards extremities table all ones that
            # reference pdus that we have already seen
            query = (
                "DELETE FROM event_backward_extremities WHERE EXISTS ("
                "SELECT 1 FROM events "
                "WHERE "
                "event_backward_extremities.event_id = events.event_id "
                "AND not events.outlier "
                ")"
            )
            txn.execute(query)
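For a non-outlier event, the code above retires each referenced prev event as a forward extremity and then inserts the new event as a forward extremity only if nothing already points at it. A minimal sketch of those two statements against the event_edges schema introduced later in this commit (`demo_extremity_update` is a hypothetical helper, not part of the store):

    import sqlite3

    def demo_extremity_update(conn, event_id, prev_event_id, room_id):
        # The prev event is no longer a tip of the room DAG...
        conn.execute(
            "DELETE FROM event_forward_extremities "
            "WHERE event_id = ? AND room_id = ?",
            (prev_event_id, room_id),
        )
        # ...and the new event becomes one, unless some other event
        # already references it as a prev event.
        conn.execute(
            "INSERT OR IGNORE INTO event_forward_extremities "
            "(event_id, room_id) "
            "SELECT ?, ? WHERE NOT EXISTS "
            "(SELECT 1 FROM event_edges WHERE prev_event_id = ?)",
            (event_id, room_id, event_id),
        )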
    def get_backfill_events(self, room_id, event_list, limit):
        """Get a list of Events for a given room that occurred before (and
        including) the events in event_list. Return a list of max size
        `limit`.

        Args:
            txn
            room_id (str)
            event_list (list)
            limit (int)

        Returns:
            list: A list of PduTuples
        """
        return self.runInteraction(
            "get_backfill_events",
            self._get_backfill_events, room_id, event_list, limit
        )

    def _get_backfill_events(self, txn, room_id, event_list, limit):
        logger.debug(
            "_get_backfill_events: %s, %s, %s",
            room_id, repr(event_list), limit
        )

        # We seed the event_results with the things from the event_list.
        event_results = event_list

        front = event_list

        query = (
            "SELECT prev_event_id FROM event_edges "
            "WHERE room_id = ? AND event_id = ? "
            "LIMIT ?"
        )

        # We iterate through all event_ids in `front` to select their previous
        # events. These are dumped in `new_front`.
        # We continue until we reach the limit *or* new_front is empty (i.e.,
        # we've run out of things to select).
        while front and len(event_results) < limit:

            new_front = []
            for event_id in front:
                logger.debug(
                    "_backfill_interaction: id=%s",
                    event_id
                )

                txn.execute(
                    query,
                    (room_id, event_id, limit - len(event_results))
                )

                for row in txn.fetchall():
                    logger.debug(
                        "_backfill_interaction: got id=%s",
                        *row
                    )
                    new_front.append(row[0])

            front = new_front
            event_results += new_front

        # We also want to update the `prev_events` attributes before
        # returning.
        return self._get_events_txn(txn, event_results)
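The backfill query above is a bounded breadth-first walk backwards through event_edges. A stand-alone sketch of the same loop (`backfill_ids` is a hypothetical name; it returns ids only, where the real method hydrates them via `_get_events_txn`):

    def backfill_ids(conn, room_id, seed_ids, limit):
        results = list(seed_ids)
        front = list(seed_ids)
        while front and len(results) < limit:
            new_front = []
            for event_id in front:
                rows = conn.execute(
                    "SELECT prev_event_id FROM event_edges "
                    "WHERE room_id = ? AND event_id = ? LIMIT ?",
                    (room_id, event_id, limit - len(results)),
                ).fetchall()
                new_front.extend(r[0] for r in rows)
            front = new_front
            results += new_front
        return results[:limit]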
@ -1,915 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer

from ._base import SQLBaseStore, Table, JoinHelper

from synapse.federation.units import Pdu
from synapse.util.logutils import log_function

from collections import namedtuple

import logging

logger = logging.getLogger(__name__)


class PduStore(SQLBaseStore):
    """A collection of queries for handling PDUs.
    """

    def get_pdu(self, pdu_id, origin):
        """Given a pdu_id and origin, get a PDU.

        Args:
            txn
            pdu_id (str)
            origin (str)

        Returns:
            PduTuple: If the pdu does not exist in the database, returns None
        """

        return self.runInteraction(
            self._get_pdu_tuple, pdu_id, origin
        )

    def _get_pdu_tuple(self, txn, pdu_id, origin):
        res = self._get_pdu_tuples(txn, [(pdu_id, origin)])
        return res[0] if res else None

    def _get_pdu_tuples(self, txn, pdu_id_tuples):
        results = []
        for pdu_id, origin in pdu_id_tuples:
            txn.execute(
                PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"),
                (pdu_id, origin)
            )

            edges = [
                (r.prev_pdu_id, r.prev_origin)
                for r in PduEdgesTable.decode_results(txn.fetchall())
            ]

            query = (
                "SELECT %(fields)s FROM %(pdus)s as p "
                "LEFT JOIN %(state)s as s "
                "ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
                "WHERE p.pdu_id = ? AND p.origin = ? "
            ) % {
                "fields": _pdu_state_joiner.get_fields(
                    PdusTable="p", StatePdusTable="s"),
                "pdus": PdusTable.table_name,
                "state": StatePdusTable.table_name,
            }

            txn.execute(query, (pdu_id, origin))

            row = txn.fetchone()
            if row:
                results.append(PduTuple(PduEntry(*row), edges))

        return results

    def get_current_state_for_context(self, context):
        """Get a list of PDUs that represent the current state for a given
        context

        Args:
            context (str)

        Returns:
            list: A list of PduTuples
        """

        return self.runInteraction(
            self._get_current_state_for_context,
            context
        )

    def _get_current_state_for_context(self, txn, context):
        query = (
            "SELECT pdu_id, origin FROM %s WHERE context = ?"
            % CurrentStateTable.table_name
        )

        logger.debug("get_current_state %s, Args=%s", query, context)
        txn.execute(query, (context,))

        res = txn.fetchall()

        logger.debug("get_current_state %d results", len(res))

        return self._get_pdu_tuples(txn, res)

    def _persist_pdu_txn(self, txn, prev_pdus, cols):
        """Inserts a (non-state) PDU into the database.

        Args:
            txn,
            prev_pdus (list)
            **cols: The columns to insert into the PdusTable.
        """
        entry = PdusTable.EntryType(
            **{k: cols.get(k, None) for k in PdusTable.fields}
        )

        txn.execute(PdusTable.insert_statement(), entry)

        self._handle_prev_pdus(
            txn, entry.outlier, entry.pdu_id, entry.origin,
            prev_pdus, entry.context
        )

    def mark_pdu_as_processed(self, pdu_id, pdu_origin):
        """Mark a received PDU as processed.

        Args:
            txn
            pdu_id (str)
            pdu_origin (str)
        """

        return self.runInteraction(
            self._mark_as_processed, pdu_id, pdu_origin
        )

    def _mark_as_processed(self, txn, pdu_id, pdu_origin):
        txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name)

    def get_all_pdus_from_context(self, context):
        """Get a list of all PDUs for a given context."""
        return self.runInteraction(
            self._get_all_pdus_from_context, context,
        )

    def _get_all_pdus_from_context(self, txn, context):
        query = (
            "SELECT pdu_id, origin FROM %s "
            "WHERE context = ?"
        ) % PdusTable.table_name

        txn.execute(query, (context,))

        return self._get_pdu_tuples(txn, txn.fetchall())

    def get_backfill(self, context, pdu_list, limit):
        """Get a list of Pdus for a given topic that occurred before (and
        including) the pdus in pdu_list. Return a list of max size `limit`.

        Args:
            txn
            context (str)
            pdu_list (list)
            limit (int)

        Returns:
            list: A list of PduTuples
        """
        return self.runInteraction(
            self._get_backfill, context, pdu_list, limit
        )

    def _get_backfill(self, txn, context, pdu_list, limit):
        logger.debug(
            "backfill: %s, %s, %s",
            context, repr(pdu_list), limit
        )

        # We seed the pdu_results with the things from the pdu_list.
        pdu_results = pdu_list

        front = pdu_list

        query = (
            "SELECT prev_pdu_id, prev_origin FROM %(edges_table)s "
            "WHERE context = ? AND pdu_id = ? AND origin = ? "
            "LIMIT ?"
        ) % {
            "edges_table": PduEdgesTable.table_name,
        }

        # We iterate through all pdu_ids in `front` to select their previous
        # pdus. These are dumped in `new_front`. We continue until we reach
        # the limit *or* new_front is empty (i.e., we've run out of things to
        # select).
        while front and len(pdu_results) < limit:

            new_front = []
            for pdu_id, origin in front:
                logger.debug(
                    "_backfill_interaction: i=%s, o=%s",
                    pdu_id, origin
                )

                txn.execute(
                    query,
                    (context, pdu_id, origin, limit - len(pdu_results))
                )

                for row in txn.fetchall():
                    logger.debug(
                        "_backfill_interaction: got i=%s, o=%s",
                        *row
                    )
                    new_front.append(row)

            front = new_front
            pdu_results += new_front

        # We also want to update the `prev_pdus` attributes before returning.
        return self._get_pdu_tuples(txn, pdu_results)

    def get_min_depth_for_context(self, context):
        """Get the current minimum depth for a context

        Args:
            txn
            context (str)
        """
        return self.runInteraction(
            self._get_min_depth_for_context, context
        )

    def _get_min_depth_for_context(self, txn, context):
        return self._get_min_depth_interaction(txn, context)

    def _get_min_depth_interaction(self, txn, context):
        txn.execute(
            "SELECT min_depth FROM %s WHERE context = ?"
            % ContextDepthTable.table_name,
            (context,)
        )

        row = txn.fetchone()

        return row[0] if row else None

    def _update_min_depth_for_context_txn(self, txn, context, depth):
        """Update the minimum `depth` of the given context, which is the line
        on which we stop backfilling backwards.

        Args:
            context (str)
            depth (int)
        """
        min_depth = self._get_min_depth_interaction(txn, context)

        do_insert = depth < min_depth if min_depth else True

        if do_insert:
            txn.execute(
                "INSERT OR REPLACE INTO %s (context, min_depth) "
                "VALUES (?,?)" % ContextDepthTable.table_name,
                (context, depth)
            )

    def _get_latest_pdus_in_context(self, txn, context):
        """Gets a list of the most current pdus for a given context. This is
        used when we are sending a Pdu and need to fill out the `prev_pdus`
        key.

        Args:
            txn
            context
        """
        query = (
            "SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p "
            "INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id "
            "AND f.origin = p.origin "
            "WHERE f.context = ?"
        ) % {
            "pdus": PdusTable.table_name,
            "forward": PduForwardExtremitiesTable.table_name,
        }

        logger.debug("get_prev query: %s", query)

        txn.execute(
            query,
            (context, )
        )

        results = txn.fetchall()

        return [(row[0], row[1], row[2]) for row in results]

    @defer.inlineCallbacks
    def get_oldest_pdus_in_context(self, context):
        """Get a list of Pdus that we haven't backfilled beyond yet (and
        haven't seen). This list is used when we want to backfill backwards
        and is the list we send to the remote server.

        Args:
            txn
            context (str)

        Returns:
            list: A list of PduIdTuple.
        """
        results = yield self._execute(
            None,
            "SELECT pdu_id, origin FROM %(back)s WHERE context = ?"
            % {"back": PduBackwardExtremitiesTable.table_name, },
            context
        )

        defer.returnValue([PduIdTuple(i, o) for i, o in results])

    def is_pdu_new(self, pdu_id, origin, context, depth):
        """For a given Pdu, try and figure out if it's 'new', i.e., if it's
        not something we got randomly from the past, for example when we
        request the current state of the room that will probably return a
        bunch of pdus from before we joined.

        Args:
            txn
            pdu_id (str)
            origin (str)
            context (str)
            depth (int)

        Returns:
            bool
        """

        return self.runInteraction(
            self._is_pdu_new,
            pdu_id=pdu_id,
            origin=origin,
            context=context,
            depth=depth
        )

    def _is_pdu_new(self, txn, pdu_id, origin, context, depth):
        # If depth > min depth in back table, then we classify it as new.
        # OR if there is nothing in the back table, then it kinda needs to
        # be a new thing.
        query = (
            "SELECT min(p.depth) FROM %(edges)s as e "
            "INNER JOIN %(back)s as b "
            "ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin "
            "INNER JOIN %(pdus)s as p "
            "ON e.pdu_id = p.pdu_id AND p.origin = e.origin "
            "WHERE p.context = ?"
        ) % {
            "pdus": PdusTable.table_name,
            "edges": PduEdgesTable.table_name,
            "back": PduBackwardExtremitiesTable.table_name,
        }

        txn.execute(query, (context,))

        min_depth, = txn.fetchone()

        if not min_depth or depth > int(min_depth):
            logger.debug(
                "is_new true: id=%s, o=%s, d=%s min_depth=%s",
                pdu_id, origin, depth, min_depth
            )
            return True

        # If this pdu is in the forwards table, then it also is a new one
        query = (
            "SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?"
        ) % {
            "forward": PduForwardExtremitiesTable.table_name,
        }

        txn.execute(query, (pdu_id, origin))

        # Did we get anything?
        if txn.fetchall():
            logger.debug(
                "is_new true: id=%s, o=%s, d=%s was forward",
                pdu_id, origin, depth
            )
            return True

        logger.debug(
            "is_new false: id=%s, o=%s, d=%s",
            pdu_id, origin, depth
        )

        # FINE THEN. It's probably old.
        return False

    @staticmethod
    @log_function
    def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus,
                          context):
        txn.executemany(
            PduEdgesTable.insert_statement(),
            [(pdu_id, origin, p[0], p[1], context) for p in prev_pdus]
        )

        # Update the extremities table if this is not an outlier.
        if not outlier:

            # First, we delete the new one from the forwards extremities
            # table.
            query = (
                "DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
                % PduForwardExtremitiesTable.table_name
            )
            txn.executemany(query, prev_pdus)

            # We only insert as a forward extremity the new pdu if there are
            # no other pdus that reference it as a prev pdu
            query = (
                "INSERT INTO %(table)s (pdu_id, origin, context) "
                "SELECT ?, ?, ? WHERE NOT EXISTS ("
                "SELECT 1 FROM %(pdu_edges)s WHERE "
                "prev_pdu_id = ? AND prev_origin = ?"
                ")"
            ) % {
                "table": PduForwardExtremitiesTable.table_name,
                "pdu_edges": PduEdgesTable.table_name
            }

            logger.debug("query: %s", query)

            txn.execute(query, (pdu_id, origin, context, pdu_id, origin))

            # Insert all the prev_pdus as a backwards thing, they'll get
            # deleted in a second if they're incorrect anyway.
            txn.executemany(
                PduBackwardExtremitiesTable.insert_statement(),
                [(i, o, context) for i, o in prev_pdus]
            )

            # Also delete from the backwards extremities table all ones that
            # reference pdus that we have already seen
            query = (
                "DELETE FROM %(pdu_back)s WHERE EXISTS ("
                "SELECT 1 FROM %(pdus)s AS pdus "
                "WHERE "
                "%(pdu_back)s.pdu_id = pdus.pdu_id "
                "AND %(pdu_back)s.origin = pdus.origin "
                "AND not pdus.outlier "
                ")"
            ) % {
                "pdu_back": PduBackwardExtremitiesTable.table_name,
                "pdus": PdusTable.table_name,
            }
            txn.execute(query)


class StatePduStore(SQLBaseStore):
    """A collection of queries for handling state PDUs.
    """

    def _persist_state_txn(self, txn, prev_pdus, cols):
        """Inserts a state PDU into the database

        Args:
            txn,
            prev_pdus (list)
            **cols: The columns to insert into the PdusTable and
                StatePdusTable
        """
        pdu_entry = PdusTable.EntryType(
            **{k: cols.get(k, None) for k in PdusTable.fields}
        )
        state_entry = StatePdusTable.EntryType(
            **{k: cols.get(k, None) for k in StatePdusTable.fields}
        )

        logger.debug("Inserting pdu: %s", repr(pdu_entry))
        logger.debug("Inserting state: %s", repr(state_entry))

        txn.execute(PdusTable.insert_statement(), pdu_entry)
        txn.execute(StatePdusTable.insert_statement(), state_entry)

        self._handle_prev_pdus(
            txn,
            pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus,
            pdu_entry.context
        )

    def get_unresolved_state_tree(self, new_state_pdu):
        return self.runInteraction(
            self._get_unresolved_state_tree, new_state_pdu
        )

    @log_function
    def _get_unresolved_state_tree(self, txn, new_pdu):
        current = self._get_current_interaction(
            txn,
            new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
        )

        ReturnType = namedtuple(
            "StateReturnType", ["new_branch", "current_branch"]
        )
        return_value = ReturnType([new_pdu], [])

        if not current:
            logger.debug("get_unresolved_state_tree No current state.")
            return (return_value, None)

        return_value.current_branch.append(current)

        enum_branches = self._enumerate_state_branches(
            txn, new_pdu, current
        )

        missing_branch = None
        for branch, prev_state, state in enum_branches:
            if state:
                return_value[branch].append(state)
            else:
                # We don't have prev_state :(
                missing_branch = branch
                break

        return (return_value, missing_branch)

    def update_current_state(self, pdu_id, origin, context, pdu_type,
                             state_key):
        return self.runInteraction(
            self._update_current_state,
            pdu_id, origin, context, pdu_type, state_key
        )

    def _update_current_state(self, txn, pdu_id, origin, context, pdu_type,
                              state_key):
        query = (
            "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
        ) % {
            "curr": CurrentStateTable.table_name,
            "fields": CurrentStateTable.get_fields_string(),
            "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
        }

        query_args = CurrentStateTable.EntryType(
            pdu_id=pdu_id,
            origin=origin,
            context=context,
            pdu_type=pdu_type,
            state_key=state_key
        )

        txn.execute(query, query_args)

    def get_current_state_pdu(self, context, pdu_type, state_key):
        """For a given context, pdu_type, state_key 3-tuple, return what is
        currently considered the current state.

        Args:
            txn
            context (str)
            pdu_type (str)
            state_key (str)

        Returns:
            PduEntry
        """

        return self.runInteraction(
            self._get_current_state_pdu, context, pdu_type, state_key
        )

    def _get_current_state_pdu(self, txn, context, pdu_type, state_key):
        return self._get_current_interaction(txn, context, pdu_type, state_key)

    def _get_current_interaction(self, txn, context, pdu_type, state_key):
        logger.debug(
            "_get_current_interaction %s %s %s",
            context, pdu_type, state_key
        )

        fields = _pdu_state_joiner.get_fields(
            PdusTable="p", StatePdusTable="s")

        current_query = (
            "SELECT %(fields)s FROM %(state)s as s "
            "INNER JOIN %(pdus)s as p "
            "ON s.pdu_id = p.pdu_id AND s.origin = p.origin "
            "INNER JOIN %(curr)s as c "
            "ON s.pdu_id = c.pdu_id AND s.origin = c.origin "
            "WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? "
        ) % {
            "fields": fields,
            "curr": CurrentStateTable.table_name,
            "state": StatePdusTable.table_name,
            "pdus": PdusTable.table_name,
        }

        txn.execute(
            current_query,
            (context, pdu_type, state_key)
        )

        row = txn.fetchone()

        result = PduEntry(*row) if row else None

        if not result:
            logger.debug("_get_current_interaction not found")
        else:
            logger.debug(
                "_get_current_interaction found %s %s",
                result.pdu_id, result.origin
            )

        return result

    def handle_new_state(self, new_pdu):
        """Actually perform conflict resolution on the new_pdu on the
        assumption we have all the pdus required to perform it.

        Args:
            new_pdu

        Returns:
            bool: True if the new_pdu clobbered the current state, False if
                not
        """
        return self.runInteraction(
            self._handle_new_state, new_pdu
        )

    def _handle_new_state(self, txn, new_pdu):
        logger.debug(
            "handle_new_state %s %s",
            new_pdu.pdu_id, new_pdu.origin
        )

        current = self._get_current_interaction(
            txn,
            new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
        )

        is_current = False

        if (not current or not current.prev_state_id
                or not current.prev_state_origin):
            # Oh, we don't have any state for this yet.
            is_current = True
        elif (current.pdu_id == new_pdu.prev_state_id
                and current.origin == new_pdu.prev_state_origin):
            # Oh! A direct clobber. Just do it.
            is_current = True
        else:
            ##
            # Ok, now loop through until we get to a common ancestor.
            max_new = int(new_pdu.power_level)
            max_current = int(current.power_level)

            enum_branches = self._enumerate_state_branches(
                txn, new_pdu, current
            )
            for branch, prev_state, state in enum_branches:
                if not state:
                    raise RuntimeError(
                        "Could not find state_pdu %s %s" %
                        (
                            prev_state.prev_state_id,
                            prev_state.prev_state_origin
                        )
                    )

                if branch == 0:
                    max_new = max(int(state.depth), max_new)
                else:
                    max_current = max(int(state.depth), max_current)

            is_current = max_new > max_current

        if is_current:
            logger.debug("handle_new_state make current")

            # Right, this is a new thing, so woo, just insert it.
            txn.execute(
                "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
                % {
                    "curr": CurrentStateTable.table_name,
                    "fields": CurrentStateTable.get_fields_string(),
                    "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
                },
                CurrentStateTable.EntryType(
                    *(new_pdu.__dict__[k] for k in CurrentStateTable.fields)
                )
            )
        else:
            logger.debug("handle_new_state not current")

        logger.debug("handle_new_state done")

        return is_current

    @log_function
    def _enumerate_state_branches(self, txn, pdu_a, pdu_b):
        branch_a = pdu_a
        branch_b = pdu_b

        while True:
            if (branch_a.pdu_id == branch_b.pdu_id
                    and branch_a.origin == branch_b.origin):
                # Woo! We found a common ancestor
                logger.debug("_enumerate_state_branches Found common ancestor")
                break

            do_branch_a = (
                hasattr(branch_a, "prev_state_id") and
                branch_a.prev_state_id
            )

            do_branch_b = (
                hasattr(branch_b, "prev_state_id") and
                branch_b.prev_state_id
            )

            logger.debug(
                "do_branch_a=%s, do_branch_b=%s",
                do_branch_a, do_branch_b
            )

            if do_branch_a and do_branch_b:
                do_branch_a = int(branch_a.depth) > int(branch_b.depth)

            if do_branch_a:
                pdu_tuple = PduIdTuple(
                    branch_a.prev_state_id,
                    branch_a.prev_state_origin
                )

                prev_branch = branch_a

                logger.debug("getting branch_a prev %s", pdu_tuple)
                branch_a = self._get_pdu_tuple(txn, *pdu_tuple)
                if branch_a:
                    branch_a = Pdu.from_pdu_tuple(branch_a)

                logger.debug("branch_a=%s", branch_a)

                yield (0, prev_branch, branch_a)

                if not branch_a:
                    break
            elif do_branch_b:
                pdu_tuple = PduIdTuple(
                    branch_b.prev_state_id,
                    branch_b.prev_state_origin
                )

                prev_branch = branch_b

                logger.debug("getting branch_b prev %s", pdu_tuple)
                branch_b = self._get_pdu_tuple(txn, *pdu_tuple)
                if branch_b:
                    branch_b = Pdu.from_pdu_tuple(branch_b)

                logger.debug("branch_b=%s", branch_b)

                yield (1, prev_branch, branch_b)

                if not branch_b:
                    break
            else:
                break


class PdusTable(Table):
    table_name = "pdus"

    fields = [
        "pdu_id",
        "origin",
        "context",
        "pdu_type",
        "ts",
        "depth",
        "is_state",
        "content_json",
        "unrecognized_keys",
        "outlier",
        "have_processed",
    ]

    EntryType = namedtuple("PdusEntry", fields)


class PduDestinationsTable(Table):
    table_name = "pdu_destinations"

    fields = [
        "pdu_id",
        "origin",
        "destination",
        "delivered_ts",
    ]

    EntryType = namedtuple("PduDestinationsEntry", fields)


class PduEdgesTable(Table):
    table_name = "pdu_edges"

    fields = [
        "pdu_id",
        "origin",
        "prev_pdu_id",
        "prev_origin",
        "context"
    ]

    EntryType = namedtuple("PduEdgesEntry", fields)


class PduForwardExtremitiesTable(Table):
    table_name = "pdu_forward_extremities"

    fields = [
        "pdu_id",
        "origin",
        "context",
    ]

    EntryType = namedtuple("PduForwardExtremitiesEntry", fields)


class PduBackwardExtremitiesTable(Table):
    table_name = "pdu_backward_extremities"

    fields = [
        "pdu_id",
        "origin",
        "context",
    ]

    EntryType = namedtuple("PduBackwardExtremitiesEntry", fields)


class ContextDepthTable(Table):
    table_name = "context_depth"

    fields = [
        "context",
        "min_depth",
    ]

    EntryType = namedtuple("ContextDepthEntry", fields)


class StatePdusTable(Table):
    table_name = "state_pdus"

    fields = [
        "pdu_id",
        "origin",
        "context",
        "pdu_type",
        "state_key",
        "power_level",
        "prev_state_id",
        "prev_state_origin",
    ]

    EntryType = namedtuple("StatePdusEntry", fields)


class CurrentStateTable(Table):
    table_name = "current_state"

    fields = [
        "pdu_id",
        "origin",
        "context",
        "pdu_type",
        "state_key",
    ]

    EntryType = namedtuple("CurrentStateEntry", fields)


_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable)


# TODO: These should probably be put somewhere more sensible
PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin"))

PduEntry = _pdu_state_joiner.EntryType
""" We are always interested in the join of the PdusTable and StatePdusTable,
rather than just the PdusTable.

This does not include a prev_pdus key.
"""

PduTuple = namedtuple(
    "PduTuple",
    ("pdu_entry", "prev_pdu_list")
)
""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
the `prev_pdus` key of a PDU.
"""
@ -62,8 +62,10 @@ class RegistrationStore(SQLBaseStore):
        Raises:
            StoreError if the user_id could not be registered.
        """
        yield self.runInteraction(self._register, user_id, token,
                                  password_hash)
        yield self.runInteraction(
            "register",
            self._register, user_id, token, password_hash
        )

    def _register(self, txn, user_id, token, password_hash):
        now = int(self.clock.time())
@ -100,17 +102,22 @@ class RegistrationStore(SQLBaseStore):
            StoreError if no user was found.
        """
        return self.runInteraction(
            "get_user_by_token",
            self._query_for_auth,
            token
        )

    @defer.inlineCallbacks
    def is_server_admin(self, user):
        return self._simple_select_one_onecol(
        res = yield self._simple_select_one_onecol(
            table="users",
            keyvalues={"name": user.to_string()},
            retcol="admin",
            allow_none=True,
        )

        defer.returnValue(res if res else False)

    def _query_for_auth(self, txn, token):
        sql = (
            "SELECT users.name, users.admin, access_tokens.device_id "
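The is_server_admin change above is the standard Twisted inlineCallbacks conversion: the Deferred is resolved with `yield`, and a missing row (None, via `allow_none=True`) is coerced to False before returning. The pattern in isolation (`fetch_admin_flag` is a hypothetical callable returning a Deferred):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def is_admin(fetch_admin_flag, user_name):
        # yield suspends until the Deferred fires and hands back its result
        res = yield fetch_admin_flag(user_name)
        defer.returnValue(res if res else False)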
@ -132,209 +132,29 @@ class RoomStore(SQLBaseStore):

        defer.returnValue(ret)

    @defer.inlineCallbacks
    def get_room_join_rule(self, room_id):
        sql = (
            "SELECT join_rule FROM room_join_rules as r "
            "INNER JOIN current_state_events as c "
            "ON r.event_id = c.event_id "
            "WHERE c.room_id = ? "
        )

        rows = yield self._execute(None, sql, room_id)

        if len(rows) == 1:
            defer.returnValue(rows[0][0])
        else:
            defer.returnValue(None)

    def get_power_level(self, room_id, user_id):
        return self.runInteraction(
            self._get_power_level,
            room_id, user_id,
        )

    def _get_power_level(self, txn, room_id, user_id):
        sql = (
            "SELECT level FROM room_power_levels as r "
            "INNER JOIN current_state_events as c "
            "ON r.event_id = c.event_id "
            "WHERE c.room_id = ? AND r.user_id = ? "
        )

        rows = txn.execute(sql, (room_id, user_id,)).fetchall()

        if len(rows) == 1:
            return rows[0][0]

        sql = (
            "SELECT level FROM room_default_levels as r "
            "INNER JOIN current_state_events as c "
            "ON r.event_id = c.event_id "
            "WHERE c.room_id = ? "
        )

        rows = txn.execute(sql, (room_id,)).fetchall()

        if len(rows) == 1:
            return rows[0][0]
        else:
            return None

    def get_ops_levels(self, room_id):
        return self.runInteraction(
            self._get_ops_levels,
            room_id,
        )

    def _get_ops_levels(self, txn, room_id):
        sql = (
            "SELECT ban_level, kick_level, redact_level "
            "FROM room_ops_levels as r "
            "INNER JOIN current_state_events as c "
            "ON r.event_id = c.event_id "
            "WHERE c.room_id = ? "
        )

        rows = txn.execute(sql, (room_id,)).fetchall()

        if len(rows) == 1:
            return OpsLevel(rows[0][0], rows[0][1], rows[0][2])
        else:
            return OpsLevel(None, None, None)

    def get_add_state_level(self, room_id):
        return self._get_level_from_table("room_add_state_levels", room_id)

    def get_send_event_level(self, room_id):
        return self._get_level_from_table("room_send_event_levels", room_id)

    @defer.inlineCallbacks
    def _get_level_from_table(self, table, room_id):
        sql = (
            "SELECT level FROM %(table)s as r "
            "INNER JOIN current_state_events as c "
            "ON r.event_id = c.event_id "
            "WHERE c.room_id = ? "
        ) % {"table": table}

        rows = yield self._execute(None, sql, room_id)

        if len(rows) == 1:
            defer.returnValue(rows[0][0])
        else:
            defer.returnValue(None)

    def _store_room_topic_txn(self, txn, event):
        self._simple_insert_txn(
            txn,
            "topics",
            {
                "event_id": event.event_id,
                "room_id": event.room_id,
                "topic": event.topic,
            }
        )
        if hasattr(event, "topic"):
            self._simple_insert_txn(
                txn,
                "topics",
                {
                    "event_id": event.event_id,
                    "room_id": event.room_id,
                    "topic": event.topic,
                }
            )

    def _store_room_name_txn(self, txn, event):
        self._simple_insert_txn(
            txn,
            "room_names",
            {
                "event_id": event.event_id,
                "room_id": event.room_id,
                "name": event.name,
            }
        )

    def _store_join_rule(self, txn, event):
        self._simple_insert_txn(
            txn,
            "room_join_rules",
            {
                "event_id": event.event_id,
                "room_id": event.room_id,
                "join_rule": event.content["join_rule"],
            },
        )

    def _store_power_levels(self, txn, event):
        for user_id, level in event.content.items():
            if user_id == "default":
                self._simple_insert_txn(
                    txn,
                    "room_default_levels",
                    {
                        "event_id": event.event_id,
                        "room_id": event.room_id,
                        "level": level,
                    },
                )
            else:
                self._simple_insert_txn(
                    txn,
                    "room_power_levels",
                    {
                        "event_id": event.event_id,
                        "room_id": event.room_id,
                        "user_id": user_id,
                        "level": level
                    },
                )

    def _store_default_level(self, txn, event):
        self._simple_insert_txn(
            txn,
            "room_default_levels",
            {
                "event_id": event.event_id,
                "room_id": event.room_id,
                "level": event.content["default_level"],
            },
        )

    def _store_add_state_level(self, txn, event):
        self._simple_insert_txn(
            txn,
            "room_add_state_levels",
            {
                "event_id": event.event_id,
                "room_id": event.room_id,
                "level": event.content["level"],
            },
        )

    def _store_send_event_level(self, txn, event):
        self._simple_insert_txn(
            txn,
            "room_send_event_levels",
            {
                "event_id": event.event_id,
                "room_id": event.room_id,
                "level": event.content["level"],
            },
        )

    def _store_ops_level(self, txn, event):
        content = {
            "event_id": event.event_id,
            "room_id": event.room_id,
        }

        if "kick_level" in event.content:
            content["kick_level"] = event.content["kick_level"]

        if "ban_level" in event.content:
            content["ban_level"] = event.content["ban_level"]

        if "redact_level" in event.content:
            content["redact_level"] = event.content["redact_level"]

        self._simple_insert_txn(
            txn,
            "room_ops_levels",
            content,
        )
        if hasattr(event, "name"):
            self._simple_insert_txn(
                txn,
                "room_names",
                {
                    "event_id": event.event_id,
                    "room_id": event.room_id,
                    "name": event.name,
                }
            )


class RoomsTable(Table):
@ -1,31 +0,0 @@
/* Copyright 2014 OpenMarket Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
CREATE TABLE IF NOT EXISTS context_edge_pdus(
    id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
    pdu_id TEXT,
    origin TEXT,
    context TEXT,
    CONSTRAINT context_edge_pdu_id_origin UNIQUE (pdu_id, origin)
);

CREATE TABLE IF NOT EXISTS origin_edge_pdus(
    id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
    pdu_id TEXT,
    origin TEXT,
    CONSTRAINT origin_edge_pdu_id_origin UNIQUE (pdu_id, origin)
);

CREATE INDEX IF NOT EXISTS context_edge_pdu_id ON context_edge_pdus(pdu_id, origin);
CREATE INDEX IF NOT EXISTS origin_edge_pdu_id ON origin_edge_pdus(pdu_id, origin);
75 synapse/storage/schema/event_edges.sql Normal file
@ -0,0 +1,75 @@

CREATE TABLE IF NOT EXISTS event_forward_extremities(
    event_id TEXT NOT NULL,
    room_id TEXT NOT NULL,
    CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
);

CREATE INDEX IF NOT EXISTS ev_extrem_room ON event_forward_extremities(room_id);
CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id);


CREATE TABLE IF NOT EXISTS event_backward_extremities(
    event_id TEXT NOT NULL,
    room_id TEXT NOT NULL,
    CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
);

CREATE INDEX IF NOT EXISTS ev_b_extrem_room ON event_backward_extremities(room_id);
CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id);


CREATE TABLE IF NOT EXISTS event_edges(
    event_id TEXT NOT NULL,
    prev_event_id TEXT NOT NULL,
    room_id TEXT NOT NULL,
    is_state INTEGER NOT NULL,
    CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id, is_state)
);

CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id);
CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id);


CREATE TABLE IF NOT EXISTS room_depth(
    room_id TEXT NOT NULL,
    min_depth INTEGER NOT NULL,
    CONSTRAINT uniqueness UNIQUE (room_id)
);

CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id);


create TABLE IF NOT EXISTS event_destinations(
    event_id TEXT NOT NULL,
    destination TEXT NOT NULL,
    delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
    CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE
);

CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id);


CREATE TABLE IF NOT EXISTS state_forward_extremities(
    event_id TEXT NOT NULL,
    room_id TEXT NOT NULL,
    type TEXT NOT NULL,
    state_key TEXT NOT NULL,
    CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
);

CREATE INDEX IF NOT EXISTS st_extrem_keys ON state_forward_extremities(
    room_id, type, state_key
);
CREATE INDEX IF NOT EXISTS st_extrem_id ON state_forward_extremities(event_id);


CREATE TABLE IF NOT EXISTS event_auth(
    event_id TEXT NOT NULL,
    auth_id TEXT NOT NULL,
    room_id TEXT NOT NULL,
    CONSTRAINT uniqueness UNIQUE (event_id, auth_id, room_id)
);

CREATE INDEX IF NOT EXISTS evauth_edges_id ON event_auth(event_id);
CREATE INDEX IF NOT EXISTS evauth_edges_auth_id ON event_auth(auth_id);
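To see how these tables fit together: event_edges records the room DAG and event_forward_extremities holds its current tips, which is exactly what the EventFederationStore code earlier in this commit queries. A small sqlite3 session (table definitions abbreviated from the schema above):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE event_edges(
            event_id TEXT NOT NULL, prev_event_id TEXT NOT NULL,
            room_id TEXT NOT NULL, is_state INTEGER NOT NULL);
        CREATE TABLE event_forward_extremities(
            event_id TEXT NOT NULL, room_id TEXT NOT NULL);
    """)
    # $b extends $a, so $b is now the only tip of the room's DAG.
    conn.execute("INSERT INTO event_edges VALUES ('$b', '$a', '!room', 0)")
    conn.execute("INSERT INTO event_forward_extremities VALUES ('$b', '!room')")
    print(conn.execute(
        "SELECT event_id FROM event_forward_extremities WHERE room_id = ?",
        ("!room",),
    ).fetchall())  # [('$b',)]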
65 synapse/storage/schema/event_signatures.sql Normal file
@ -0,0 +1,65 @@
/* Copyright 2014 OpenMarket Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

CREATE TABLE IF NOT EXISTS event_content_hashes (
    event_id TEXT,
    algorithm TEXT,
    hash BLOB,
    CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
);

CREATE INDEX IF NOT EXISTS event_content_hashes_id ON event_content_hashes(
    event_id
);


CREATE TABLE IF NOT EXISTS event_reference_hashes (
    event_id TEXT,
    algorithm TEXT,
    hash BLOB,
    CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
);

CREATE INDEX IF NOT EXISTS event_reference_hashes_id ON event_reference_hashes (
    event_id
);


CREATE TABLE IF NOT EXISTS event_origin_signatures (
    event_id TEXT,
    origin TEXT,
    key_id TEXT,
    signature BLOB,
    CONSTRAINT uniqueness UNIQUE (event_id, key_id)
);

CREATE INDEX IF NOT EXISTS event_origin_signatures_id ON event_origin_signatures (
    event_id
);


CREATE TABLE IF NOT EXISTS event_edge_hashes(
    event_id TEXT,
    prev_event_id TEXT,
    algorithm TEXT,
    hash BLOB,
    CONSTRAINT uniqueness UNIQUE (
        event_id, prev_event_id, algorithm
    )
);

CREATE INDEX IF NOT EXISTS event_edge_hashes_id ON event_edge_hashes(
    event_id
);
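These tables store hashes and signatures as raw BLOBs; the storage code earlier in this commit base64-encodes them (via encode_base64, unpadded) when events are read back out. A quick round-trip sketch for event_reference_hashes:

    import sqlite3
    from base64 import b64encode

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE event_reference_hashes ("
        "event_id TEXT, algorithm TEXT, hash BLOB)"
    )
    conn.execute(
        "INSERT INTO event_reference_hashes VALUES (?, ?, ?)",
        ("$event:example.com", "sha256", sqlite3.Binary(b"\x01\x02\x03")),
    )
    alg, h = conn.execute(
        "SELECT algorithm, hash FROM event_reference_hashes "
        "WHERE event_id = ?", ("$event:example.com",)
    ).fetchone()
    print(alg, b64encode(bytes(h)).rstrip(b"=").decode())  # sha256 AQID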
@ -23,6 +23,7 @@ CREATE TABLE IF NOT EXISTS events(
    unrecognized_keys TEXT,
    processed BOOL NOT NULL,
    outlier BOOL NOT NULL,
    depth INTEGER DEFAULT 0 NOT NULL,
    CONSTRAINT ev_uniq UNIQUE (event_id)
);

@ -84,80 +85,24 @@ CREATE TABLE IF NOT EXISTS topics(
    topic TEXT NOT NULL
);

CREATE INDEX IF NOT EXISTS topics_event_id ON topics(event_id);
CREATE INDEX IF NOT EXISTS topics_room_id ON topics(room_id);

CREATE TABLE IF NOT EXISTS room_names(
    event_id TEXT NOT NULL,
    room_id TEXT NOT NULL,
    name TEXT NOT NULL
);

CREATE INDEX IF NOT EXISTS room_names_event_id ON room_names(event_id);
CREATE INDEX IF NOT EXISTS room_names_room_id ON room_names(room_id);

CREATE TABLE IF NOT EXISTS rooms(
    room_id TEXT PRIMARY KEY NOT NULL,
    is_public INTEGER,
    creator TEXT
);

CREATE TABLE IF NOT EXISTS room_join_rules(
    event_id TEXT NOT NULL,
    room_id TEXT NOT NULL,
    join_rule TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS room_join_rules_event_id ON room_join_rules(event_id);
CREATE INDEX IF NOT EXISTS room_join_rules_room_id ON room_join_rules(room_id);


CREATE TABLE IF NOT EXISTS room_power_levels(
    event_id TEXT NOT NULL,
    room_id TEXT NOT NULL,
    user_id TEXT NOT NULL,
    level INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS room_power_levels_event_id ON room_power_levels(event_id);
CREATE INDEX IF NOT EXISTS room_power_levels_room_id ON room_power_levels(room_id);
CREATE INDEX IF NOT EXISTS room_power_levels_room_user ON room_power_levels(room_id, user_id);


CREATE TABLE IF NOT EXISTS room_default_levels(
    event_id TEXT NOT NULL,
    room_id TEXT NOT NULL,
    level INTEGER NOT NULL
);

CREATE INDEX IF NOT EXISTS room_default_levels_event_id ON room_default_levels(event_id);
CREATE INDEX IF NOT EXISTS room_default_levels_room_id ON room_default_levels(room_id);


CREATE TABLE IF NOT EXISTS room_add_state_levels(
    event_id TEXT NOT NULL,
    room_id TEXT NOT NULL,
    level INTEGER NOT NULL
);

CREATE INDEX IF NOT EXISTS room_add_state_levels_event_id ON room_add_state_levels(event_id);
CREATE INDEX IF NOT EXISTS room_add_state_levels_room_id ON room_add_state_levels(room_id);


CREATE TABLE IF NOT EXISTS room_send_event_levels(
    event_id TEXT NOT NULL,
    room_id TEXT NOT NULL,
    level INTEGER NOT NULL
);

CREATE INDEX IF NOT EXISTS room_send_event_levels_event_id ON room_send_event_levels(event_id);
CREATE INDEX IF NOT EXISTS room_send_event_levels_room_id ON room_send_event_levels(room_id);


CREATE TABLE IF NOT EXISTS room_ops_levels(
    event_id TEXT NOT NULL,
    room_id TEXT NOT NULL,
    ban_level INTEGER,
    kick_level INTEGER,
    redact_level INTEGER
);

CREATE INDEX IF NOT EXISTS room_ops_levels_event_id ON room_ops_levels(event_id);
CREATE INDEX IF NOT EXISTS room_ops_levels_room_id ON room_ops_levels(room_id);


CREATE TABLE IF NOT EXISTS room_hosts(
    room_id TEXT NOT NULL,
    host TEXT NOT NULL,
@ -1,106 +0,0 @@
/* Copyright 2014 OpenMarket Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-- Stores pdus and their content
CREATE TABLE IF NOT EXISTS pdus(
    pdu_id TEXT,
    origin TEXT,
    context TEXT,
    pdu_type TEXT,
    ts INTEGER,
    depth INTEGER DEFAULT 0 NOT NULL,
    is_state BOOL,
    content_json TEXT,
    unrecognized_keys TEXT,
    outlier BOOL NOT NULL,
    have_processed BOOL,
    CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
);

-- Stores what the current state pdu is for a given (context, pdu_type, key) tuple
CREATE TABLE IF NOT EXISTS state_pdus(
    pdu_id TEXT,
    origin TEXT,
    context TEXT,
    pdu_type TEXT,
    state_key TEXT,
    power_level TEXT,
    prev_state_id TEXT,
    prev_state_origin TEXT,
    CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
    CONSTRAINT prev_pdu_id_origin UNIQUE (prev_state_id, prev_state_origin)
);

CREATE TABLE IF NOT EXISTS current_state(
    pdu_id TEXT,
    origin TEXT,
    context TEXT,
    pdu_type TEXT,
    state_key TEXT,
    CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
    CONSTRAINT uniqueness UNIQUE (context, pdu_type, state_key) ON CONFLICT REPLACE
);

-- Stores where each pdu we want to send should be sent and the delivery status.
create TABLE IF NOT EXISTS pdu_destinations(
    pdu_id TEXT,
    origin TEXT,
    destination TEXT,
    delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
    CONSTRAINT uniqueness UNIQUE (pdu_id, origin, destination) ON CONFLICT REPLACE
);

CREATE TABLE IF NOT EXISTS pdu_forward_extremities(
    pdu_id TEXT,
    origin TEXT,
    context TEXT,
    CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
);

CREATE TABLE IF NOT EXISTS pdu_backward_extremities(
    pdu_id TEXT,
    origin TEXT,
    context TEXT,
    CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
);

CREATE TABLE IF NOT EXISTS pdu_edges(
    pdu_id TEXT,
    origin TEXT,
    prev_pdu_id TEXT,
    prev_origin TEXT,
    context TEXT,
    CONSTRAINT uniqueness UNIQUE (pdu_id, origin, prev_pdu_id, prev_origin, context)
);

CREATE TABLE IF NOT EXISTS context_depth(
    context TEXT,
    min_depth INTEGER,
    CONSTRAINT uniqueness UNIQUE (context)
);

CREATE INDEX IF NOT EXISTS context_depth_context ON context_depth(context);


CREATE INDEX IF NOT EXISTS pdu_id ON pdus(pdu_id, origin);

CREATE INDEX IF NOT EXISTS dests_id ON pdu_destinations (pdu_id, origin);
-- CREATE INDEX IF NOT EXISTS dests ON pdu_destinations (destination);

CREATE INDEX IF NOT EXISTS pdu_extrem_context ON pdu_forward_extremities(context);
CREATE INDEX IF NOT EXISTS pdu_extrem_id ON pdu_forward_extremities(pdu_id, origin);

CREATE INDEX IF NOT EXISTS pdu_edges_id ON pdu_edges(pdu_id, origin);

CREATE INDEX IF NOT EXISTS pdu_b_extrem_context ON pdu_backward_extremities(context);
33
synapse/storage/schema/state.sql
Normal file
@ -0,0 +1,33 @@
/* Copyright 2014 OpenMarket Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

CREATE TABLE IF NOT EXISTS state_groups(
    id INTEGER PRIMARY KEY,
    room_id TEXT NOT NULL,
    event_id TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS state_groups_state(
    state_group INTEGER NOT NULL,
    room_id TEXT NOT NULL,
    type TEXT NOT NULL,
    state_key TEXT NOT NULL,
    event_id TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS event_to_state_groups(
    event_id TEXT NOT NULL,
    state_group INTEGER NOT NULL
);
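The three tables above form a simple indirection: event_to_state_groups maps an event to a state group, and state_groups_state enumerates the state event ids in that group. A minimal sketch of resolving an event's state directly against this schema (illustrative only; sqlite3 and the function name are assumptions, the actual code goes through the storage layer added below):

    import sqlite3

    def state_event_ids_for_event(conn, event_id):
        # Follow the indirection: event -> state_group -> state event ids.
        cur = conn.execute(
            "SELECT sgs.event_id"
            " FROM event_to_state_groups AS etsg"
            " INNER JOIN state_groups_state AS sgs"
            " ON sgs.state_group = etsg.state_group"
            " WHERE etsg.event_id = ?",
            (event_id,)
        )
        return [row[0] for row in cur.fetchall()]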
177
synapse/storage/signatures.py
Normal file
@ -0,0 +1,177 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from _base import SQLBaseStore


class SignatureStore(SQLBaseStore):
    """Persistence for event signatures and hashes"""

    def _get_event_content_hashes_txn(self, txn, event_id):
        """Get all the hashes for a given Event.
        Args:
            txn (cursor):
            event_id (str): Id for the Event.
        Returns:
            A dict of algorithm -> hash.
        """
        query = (
            "SELECT algorithm, hash"
            " FROM event_content_hashes"
            " WHERE event_id = ?"
        )
        txn.execute(query, (event_id, ))
        return dict(txn.fetchall())

    def _store_event_content_hash_txn(self, txn, event_id, algorithm,
                                      hash_bytes):
        """Store a hash for an Event
        Args:
            txn (cursor):
            event_id (str): Id for the Event.
            algorithm (str): Hashing algorithm.
            hash_bytes (bytes): Hash function output bytes.
        """
        self._simple_insert_txn(
            txn,
            "event_content_hashes",
            {
                "event_id": event_id,
                "algorithm": algorithm,
                "hash": buffer(hash_bytes),
            },
            or_ignore=True,
        )

    def get_event_reference_hashes(self, event_ids):
        def f(txn):
            return [
                self._get_event_reference_hashes_txn(txn, ev)
                for ev in event_ids
            ]

        return self.runInteraction(
            "get_event_reference_hashes",
            f
        )

    def _get_event_reference_hashes_txn(self, txn, event_id):
        """Get all the hashes for a given PDU.
        Args:
            txn (cursor):
            event_id (str): Id for the Event.
        Returns:
            A dict of algorithm -> hash.
        """
        query = (
            "SELECT algorithm, hash"
            " FROM event_reference_hashes"
            " WHERE event_id = ?"
        )
        txn.execute(query, (event_id, ))
        return dict(txn.fetchall())

    def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
                                        hash_bytes):
        """Store a hash for a PDU
        Args:
            txn (cursor):
            event_id (str): Id for the Event.
            algorithm (str): Hashing algorithm.
            hash_bytes (bytes): Hash function output bytes.
        """
        self._simple_insert_txn(
            txn,
            "event_reference_hashes",
            {
                "event_id": event_id,
                "algorithm": algorithm,
                "hash": buffer(hash_bytes),
            },
            or_ignore=True,
        )


    def _get_event_origin_signatures_txn(self, txn, event_id):
        """Get all the signatures for a given PDU.
        Args:
            txn (cursor):
            event_id (str): Id for the Event.
        Returns:
            A dict of key_id -> signature_bytes.
        """
        query = (
            "SELECT key_id, signature"
            " FROM event_origin_signatures"
            " WHERE event_id = ? "
        )
        txn.execute(query, (event_id, ))
        return dict(txn.fetchall())

    def _store_event_origin_signature_txn(self, txn, event_id, origin, key_id,
                                          signature_bytes):
        """Store a signature from the origin server for a PDU.
        Args:
            txn (cursor):
            event_id (str): Id for the Event.
            origin (str): origin of the Event.
            key_id (str): Id for the signing key.
            signature_bytes (bytes): The signature.
        """
        self._simple_insert_txn(
            txn,
            "event_origin_signatures",
            {
                "event_id": event_id,
                "origin": origin,
                "key_id": key_id,
                "signature": buffer(signature_bytes),
            },
            or_ignore=True,
        )

    def _get_prev_event_hashes_txn(self, txn, event_id):
        """Get all the hashes for previous PDUs of a PDU
        Args:
            txn (cursor):
            event_id (str): Id for the Event.
        Returns:
            dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes.
        """
        query = (
            "SELECT prev_event_id, algorithm, hash"
            " FROM event_edge_hashes"
            " WHERE event_id = ?"
        )
        txn.execute(query, (event_id, ))
        results = {}
        for prev_event_id, algorithm, hash_bytes in txn.fetchall():
            hashes = results.setdefault(prev_event_id, {})
            hashes[algorithm] = hash_bytes
        return results

    def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id,
                                   algorithm, hash_bytes):
        self._simple_insert_txn(
            txn,
            "event_edge_hashes",
            {
                "event_id": event_id,
                "prev_event_id": prev_event_id,
                "algorithm": algorithm,
                "hash": buffer(hash_bytes),
            },
            or_ignore=True,
        )
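A hypothetical caller for the hash helpers above, to show how the pieces fit together: a digest is computed over the event's canonical JSON bytes and persisted under its algorithm name. The function and its arguments are illustrative, not part of this commit:

    import hashlib

    def store_sha256_content_hash(store, txn, event_id, canonical_json_bytes):
        # Compute the digest, then persist it via the transaction helper
        # defined on SignatureStore above.
        hash_bytes = hashlib.sha256(canonical_json_bytes).digest()
        store._store_event_content_hash_txn(txn, event_id, "sha256", hash_bytes)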
96
synapse/storage/state.py
Normal file
@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ._base import SQLBaseStore
from twisted.internet import defer


class StateStore(SQLBaseStore):

    @defer.inlineCallbacks
    def get_state_groups(self, event_ids):
        groups = set()
        for event_id in event_ids:
            group = yield self._simple_select_one_onecol(
                table="event_to_state_groups",
                keyvalues={"event_id": event_id},
                retcol="state_group",
                allow_none=True,
            )
            if group:
                groups.add(group)

        res = {}
        for group in groups:
            state_ids = yield self._simple_select_onecol(
                table="state_groups_state",
                keyvalues={"state_group": group},
                retcol="event_id",
            )
            state = []
            for state_id in state_ids:
                s = yield self.get_event(
                    state_id,
                    allow_none=True,
                )
                if s:
                    state.append(s)

            res[group] = state

        defer.returnValue(res)

    def store_state_groups(self, event):
        return self.runInteraction(
            "store_state_groups",
            self._store_state_groups_txn, event
        )

    def _store_state_groups_txn(self, txn, event):
        if not event.state_events:
            return

        state_group = event.state_group
        if not state_group:
            state_group = self._simple_insert_txn(
                txn,
                table="state_groups",
                values={
                    "room_id": event.room_id,
                    "event_id": event.event_id,
                }
            )

            for state in event.state_events.values():
                self._simple_insert_txn(
                    txn,
                    table="state_groups_state",
                    values={
                        "state_group": state_group,
                        "room_id": state.room_id,
                        "type": state.type,
                        "state_key": state.state_key,
                        "event_id": state.event_id,
                    }
                )

        self._simple_insert_txn(
            txn,
            table="event_to_state_groups",
            values={
                "state_group": state_group,
                "event_id": event.event_id,
            }
        )
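A sketch of how StateStore might be driven end to end, assuming an event object with the attributes the methods above expect (state_events, state_group, room_id, event_id); illustrative only, not code from the commit:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def example(store, event):
        # Assign (or reuse) a state group for the event's resolved state.
        yield store.store_state_groups(event)

        # Later, map a batch of event ids back to their full state events.
        groups = yield store.get_state_groups([event.event_id])
        for group, state in groups.items():
            print group, [s.event_id for s in state]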
@ -177,10 +177,9 @@ class StreamStore(SQLBaseStore):

        sql = (
            "SELECT *, (%(redacted)s) AS redacted FROM events AS e WHERE "
            "((room_id IN (%(current)s)) OR "
            "(e.outlier = 0 AND (room_id IN (%(current)s)) OR "
            "(event_id IN (%(invites)s))) "
            "AND e.stream_ordering > ? AND e.stream_ordering <= ? "
            "AND e.outlier = 0 "
            "ORDER BY stream_ordering ASC LIMIT %(limit)d "
        ) % {
            "redacted": del_sql,
@ -309,7 +308,10 @@ class StreamStore(SQLBaseStore):
        defer.returnValue(ret)

    def get_room_events_max_id(self):
        return self.runInteraction(self._get_room_events_max_id_txn)
        return self.runInteraction(
            "get_room_events_max_id",
            self._get_room_events_max_id_txn
        )

    def _get_room_events_max_id_txn(self, txn):
        txn.execute(
@ -14,7 +14,6 @@
# limitations under the License.

from ._base import SQLBaseStore, Table
from .pdu import PdusTable

from collections import namedtuple

@ -42,6 +41,7 @@ class TransactionStore(SQLBaseStore):
        """

        return self.runInteraction(
            "get_received_txn_response",
            self._get_received_txn_response, transaction_id, origin
        )

@ -73,6 +73,7 @@ class TransactionStore(SQLBaseStore):
        """

        return self.runInteraction(
            "set_received_txn_response",
            self._set_received_txn_response,
            transaction_id, origin, code, response_dict
        )
@ -88,7 +89,7 @@ class TransactionStore(SQLBaseStore):
        txn.execute(query, (code, response_json, transaction_id, origin))

    def prep_send_transaction(self, transaction_id, destination,
                              origin_server_ts, pdu_list):
                              origin_server_ts):
        """Persists an outgoing transaction and calculates the values for the
        previous transaction id list.

@ -99,19 +100,19 @@ class TransactionStore(SQLBaseStore):
            transaction_id (str)
            destination (str)
            origin_server_ts (int)
            pdu_list (list)

        Returns:
            list: A list of previous transaction ids.
        """

        return self.runInteraction(
            "prep_send_transaction",
            self._prep_send_transaction,
            transaction_id, destination, origin_server_ts, pdu_list
            transaction_id, destination, origin_server_ts
        )

    def _prep_send_transaction(self, txn, transaction_id, destination,
                               origin_server_ts, pdu_list):
                               origin_server_ts):

        # First we find out what the prev_txs should be.
        # Since we know that we are only sending one transaction at a time,
@ -139,15 +140,15 @@ class TransactionStore(SQLBaseStore):

        # Update the tx id -> pdu id mapping

        values = [
            (transaction_id, destination, pdu[0], pdu[1])
            for pdu in pdu_list
        ]

        logger.debug("Inserting: %s", repr(values))

        query = TransactionsToPduTable.insert_statement()
        txn.executemany(query, values)
        # values = [
        #     (transaction_id, destination, pdu[0], pdu[1])
        #     for pdu in pdu_list
        # ]
        #
        # logger.debug("Inserting: %s", repr(values))
        #
        # query = TransactionsToPduTable.insert_statement()
        # txn.executemany(query, values)

        return prev_txns

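The prev_txns bookkeeping that _prep_send_transaction performs can be pictured as a per-destination chain: each new transaction records the id of the last transaction sent to the same destination. A standalone sketch of that idea (names hypothetical, not the actual implementation):

    _last_txn_by_destination = {}

    def compute_prev_txns(destination, new_transaction_id):
        # The previous transaction id for this destination becomes the new
        # transaction's prev_txns; then the pointer is advanced.
        prev = _last_txn_by_destination.get(destination)
        _last_txn_by_destination[destination] = new_transaction_id
        return [prev] if prev is not None else []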
@ -161,6 +162,7 @@ class TransactionStore(SQLBaseStore):
            response_json (str)
        """
        return self.runInteraction(
            "delivered_txn",
            self._delivered_txn,
            transaction_id, destination, code, response_dict
        )
@ -186,6 +188,7 @@ class TransactionStore(SQLBaseStore):
            list: A list of `ReceivedTransactionsTable.EntryType`
        """
        return self.runInteraction(
            "get_transactions_after",
            self._get_transactions_after, transaction_id, destination
        )

@ -202,49 +205,6 @@ class TransactionStore(SQLBaseStore):

        return ReceivedTransactionsTable.decode_results(txn.fetchall())

    def get_pdus_after_transaction(self, transaction_id, destination):
        """For a given local transaction_id that we sent to a given destination
        home server, return a list of PDUs that were sent to that destination
        after it.

        Args:
            txn
            transaction_id (str)
            destination (str)

        Returns
            list: A list of PduTuple
        """
        return self.runInteraction(
            self._get_pdus_after_transaction,
            transaction_id, destination
        )

    def _get_pdus_after_transaction(self, txn, transaction_id, destination):

        # Query that first gets all transaction_ids with an id greater than
        # the one given from the `sent_transactions` table. Then JOIN on this
        # from the `tx->pdu` table to get a list of (pdu_id, origin) that
        # specify the pdus that were sent in those transactions.
        query = (
            "SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp "
            "INNER JOIN %(sent_tx)s as st "
            "ON tp.transaction_id = st.transaction_id "
            "AND tp.destination = st.destination "
            "WHERE st.id > ("
            "SELECT id FROM %(sent_tx)s "
            "WHERE transaction_id = ? AND destination = ?"
            ")"
        ) % {
            "tx_pdu": TransactionsToPduTable.table_name,
            "sent_tx": SentTransactions.table_name,
        }

        txn.execute(query, (transaction_id, destination))

        pdus = PdusTable.decode_results(txn.fetchall())

        return self._get_pdu_tuples(txn, pdus)


class ReceivedTransactionsTable(Table):
    table_name = "received_transactions"

@ -78,6 +78,11 @@ class DomainSpecificString(
        """Create a structure on the local domain"""
        return cls(localpart=localpart, domain=hs.hostname, is_mine=True)

    @classmethod
    def create(cls, localpart, domain, hs):
        is_mine = domain == hs.hostname
        return cls(localpart=localpart, domain=domain, is_mine=is_mine)


class UserID(DomainSpecificString):
    """Structure representing a user ID."""
@ -94,6 +99,11 @@ class RoomID(DomainSpecificString):
    SIGIL = "!"


class EventID(DomainSpecificString):
    """Structure representing an event id."""
    SIGIL = "$"


class StreamToken(
    namedtuple(
        "Token",

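The new create classmethod lets callers build identifiers for arbitrary domains while still tracking whether they are local. A hypothetical usage, assuming an hs whose hostname is "test":

    event_id = EventID.create("abc123", "example.com", hs)
    assert event_id.localpart == "abc123"
    assert event_id.domain == "example.com"
    assert not event_id.is_mine  # "example.com" != hs.hostname ("test")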
@ -21,3 +21,10 @@ def sleep(seconds):
    d = defer.Deferred()
    reactor.callLater(seconds, d.callback, seconds)
    return d


def run_on_reactor():
    """ This will cause the rest of the function to be invoked upon the next
    iteration of the main loop
    """
    return sleep(0)
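Because sleep(0) schedules its callback with reactor.callLater, yielding run_on_reactor() hands control back to the reactor before the function continues, so other queued work gets a chance to run first. An illustrative caller (the helper name below is hypothetical):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def handle_request():
        # Everything after this yield runs on a later reactor iteration.
        yield run_on_reactor()
        do_expensive_work()  # hypothetical helper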
@ -80,7 +80,7 @@ class JsonEncodedObject(object):

    def get_full_dict(self):
        d = {
            k: v for (k, v) in self.__dict__.items()
            k: _encode(v) for (k, v) in self.__dict__.items()
            if k in self.valid_keys or k in self.internal_keys
        }
        d.update(self.unrecognized_keys)

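Routing every value through _encode suggests that nested JsonEncodedObject instances are now serialized recursively instead of being copied into the dict as live objects. The hunk does not show _encode itself, so the following shape is an assumption:

    # Assumed shape of _encode, inferred from its use in get_full_dict;
    # the real implementation is not part of this hunk.
    def _encode(obj):
        if hasattr(obj, "get_full_dict"):
            return obj.get_full_dict()
        if isinstance(obj, list):
            return [_encode(o) for o in obj]
        return obj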
@ -14,6 +14,8 @@
# limitations under the License.

from synapse.api.events import SynapseEvent
from synapse.api.events.validator import EventValidator
from synapse.api.errors import SynapseError

from tests import unittest

@ -21,7 +23,7 @@ from tests import unittest
class SynapseTemplateCheckTestCase(unittest.TestCase):

    def setUp(self):
        pass
        self.validator = EventValidator(None)

    def tearDown(self):
        pass
@ -38,22 +40,28 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
        }

        event = MockSynapseEvent(template)
        self.assertTrue(event.check_json(content, raises=False))
        event.content = content
        self.assertTrue(self.validator.validate(event))

        content = {
            "person": {"name": "bob"},
            "friends": ["jill"],
            "enemies": ["mike"]
        }
        event = MockSynapseEvent(template)
        self.assertTrue(event.check_json(content, raises=False))
        event.content = content
        self.assertTrue(self.validator.validate(event))

        content = {
            "person": {"name": "bob"},
            # missing friends
            "enemies": ["mike", "jill"]
        }
        self.assertFalse(event.check_json(content, raises=False))
        event.content = content
        self.assertRaises(
            SynapseError,
            self.validator.validate,
            event
        )

    def test_lists(self):
        template = {
@ -67,13 +75,19 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
        }

        event = MockSynapseEvent(template)
        self.assertFalse(event.check_json(content, raises=False))
        event.content = content
        self.assertRaises(
            SynapseError,
            self.validator.validate,
            event
        )

        content = {
            "person": {"name": "bob"},
            "friends": [{"name": "jill"}, {"name": "mike"}]
        }
        self.assertTrue(event.check_json(content, raises=False))
        event.content = content
        self.assertTrue(self.validator.validate(event))

    def test_nested_lists(self):
        template = {
@ -103,7 +117,12 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
        }

        event = MockSynapseEvent(template)
        self.assertFalse(event.check_json(content, raises=False))
        event.content = content
        self.assertRaises(
            SynapseError,
            self.validator.validate,
            event
        )

        content = {
            "results": {
@ -117,7 +136,8 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
                ]
            }
        }
        self.assertTrue(event.check_json(content, raises=False))
        event.content = content
        self.assertTrue(self.validator.validate(event))

    def test_nested_keys(self):
        template = {
@ -145,7 +165,8 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
            }
        }

        self.assertTrue(event.check_json(content, raises=False))
        event.content = content
        self.assertTrue(self.validator.validate(event))

        content = {
            "person": {
@ -159,7 +180,12 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
            }
        }

        self.assertFalse(event.check_json(content, raises=False))
        event.content = content
        self.assertRaises(
            SynapseError,
            self.validator.validate,
            event
        )

        content = {
            "person": {
@ -173,7 +199,12 @@ class SynapseTemplateCheckTestCase(unittest.TestCase):
            }
        }

        self.assertFalse(event.check_json(content, raises=False))
        event.content = content
        self.assertRaises(
            SynapseError,
            self.validator.validate,
            event
        )


class MockSynapseEvent(SynapseEvent):

@ -24,7 +24,6 @@ from ..utils import MockHttpResource, MockClock, MockKey
from synapse.server import HomeServer
from synapse.federation import initialize_http_replication
from synapse.federation.units import Pdu
from synapse.storage.pdu import PduTuple, PduEntry


def make_pdu(prev_pdus=[], **kwargs):

@ -41,7 +40,7 @@ def make_pdu(prev_pdus=[], **kwargs):
    }
    pdu_fields.update(kwargs)

    return PduTuple(PduEntry(**pdu_fields), prev_pdus)
    return Pdu(prev_pdus=prev_pdus, **pdu_fields)


class FederationTestCase(unittest.TestCase):

@ -52,177 +51,185 @@ class FederationTestCase(unittest.TestCase):
            "put_json",
        ])
        self.mock_persistence = Mock(spec=[
            "get_current_state_for_context",
            "get_pdu",
            "persist_event",
            "update_min_depth_for_context",
            "prep_send_transaction",
            "delivered_txn",
            "get_received_txn_response",
            "set_received_txn_response",
        ])
        self.mock_persistence.get_received_txn_response.return_value = (
                defer.succeed(None)
            defer.succeed(None)
        )
        self.mock_config = Mock()
        self.mock_config.signing_key = [MockKey()]
        self.clock = MockClock()
        hs = HomeServer("test",
                resource_for_federation=self.mock_resource,
                http_client=self.mock_http_client,
                db_pool=None,
                datastore=self.mock_persistence,
                clock=self.clock,
                config=self.mock_config,
                keyring=Mock(),
        hs = HomeServer(
            "test",
            resource_for_federation=self.mock_resource,
            http_client=self.mock_http_client,
            db_pool=None,
            datastore=self.mock_persistence,
            clock=self.clock,
            config=self.mock_config,
            keyring=Mock(),
        )
        self.federation = initialize_http_replication(hs)
        self.distributor = hs.get_distributor()

    @defer.inlineCallbacks
    def test_get_state(self):
        self.mock_persistence.get_current_state_for_context.return_value = (
            defer.succeed([])
        )
        mock_handler = Mock(spec=[
            "get_state_for_pdu",
        ])

        self.federation.set_handler(mock_handler)

        mock_handler.get_state_for_pdu.return_value = defer.succeed([])

        # Empty context initially
        (code, response) = yield self.mock_resource.trigger("GET",
                "/_matrix/federation/v1/state/my-context/", None)
        (code, response) = yield self.mock_resource.trigger(
            "GET",
            "/_matrix/federation/v1/state/my-context/",
            None
        )
        self.assertEquals(200, code)
        self.assertFalse(response["pdus"])

        # Now let's give the context some state
        self.mock_persistence.get_current_state_for_context.return_value = (
        mock_handler.get_state_for_pdu.return_value = (
            defer.succeed([
                make_pdu(
                    pdu_id="the-pdu-id",
                    event_id="the-pdu-id",
                    origin="red",
                    context="my-context",
                    pdu_type="m.topic",
                    ts=123456789000,
                    room_id="my-context",
                    type="m.topic",
                    origin_server_ts=123456789000,
                    depth=1,
                    is_state=True,
                    content_json='{"topic":"The topic"}',
                    content={"topic": "The topic"},
                    state_key="",
                    power_level=1000,
                    prev_state_id="last-pdu-id",
                    prev_state_origin="blue",
                    prev_state="last-pdu-id",
                ),
            ])
        )

        (code, response) = yield self.mock_resource.trigger("GET",
                "/_matrix/federation/v1/state/my-context/", None)
        (code, response) = yield self.mock_resource.trigger(
            "GET",
            "/_matrix/federation/v1/state/my-context/",
            None
        )
        self.assertEquals(200, code)
        self.assertEquals(1, len(response["pdus"]))

    @defer.inlineCallbacks
    def test_get_pdu(self):
        self.mock_persistence.get_pdu.return_value = (
        mock_handler = Mock(spec=[
            "get_persisted_pdu",
        ])

        self.federation.set_handler(mock_handler)

        mock_handler.get_persisted_pdu.return_value = (
            defer.succeed(None)
        )

        (code, response) = yield self.mock_resource.trigger("GET",
                "/_matrix/federation/v1/pdu/red/abc123def456/", None)
        (code, response) = yield self.mock_resource.trigger(
            "GET",
            "/_matrix/federation/v1/event/abc123def456/",
            None
        )
        self.assertEquals(404, code)

        # Now insert such a PDU
        self.mock_persistence.get_pdu.return_value = (
        mock_handler.get_persisted_pdu.return_value = (
            defer.succeed(
                make_pdu(
                    pdu_id="abc123def456",
                    event_id="abc123def456",
                    origin="red",
                    context="my-context",
                    pdu_type="m.text",
                    ts=123456789001,
                    room_id="my-context",
                    type="m.text",
                    origin_server_ts=123456789001,
                    depth=1,
                    content_json='{"text":"Here is the message"}',
                    content={"text": "Here is the message"},
                )
            )
        )

        (code, response) = yield self.mock_resource.trigger("GET",
                "/_matrix/federation/v1/pdu/red/abc123def456/", None)
        (code, response) = yield self.mock_resource.trigger(
            "GET",
            "/_matrix/federation/v1/event/abc123def456/",
            None
        )
        self.assertEquals(200, code)
        self.assertEquals(1, len(response["pdus"]))
        self.assertEquals("m.text", response["pdus"][0]["pdu_type"])
        self.assertEquals("m.text", response["pdus"][0]["type"])

    @defer.inlineCallbacks
    def test_send_pdu(self):
        self.mock_http_client.put_json.return_value = defer.succeed(
                (200, "OK")
            (200, "OK")
        )

        pdu = Pdu(
            pdu_id="abc123def456",
            origin="red",
            destinations=["remote"],
            context="my-context",
            origin_server_ts=123456789002,
            pdu_type="m.test",
            content={"testing": "content here"},
            depth=1,
            event_id="abc123def456",
            origin="red",
            room_id="my-context",
            type="m.text",
            origin_server_ts=123456789001,
            depth=1,
            content={"text": "Here is the message"},
            destinations=["remote"],
        )

        yield self.federation.send_pdu(pdu)

        self.mock_http_client.put_json.assert_called_with(
                "remote",
                path="/_matrix/federation/v1/send/1000000/",
                data={
                    "origin_server_ts": 1000000,
                    "origin": "test",
                    "pdus": [
                        {
                            "origin": "red",
                            "pdu_id": "abc123def456",
                            "prev_pdus": [],
                            "origin_server_ts": 123456789002,
                            "context": "my-context",
                            "pdu_type": "m.test",
                            "is_state": False,
                            "content": {"testing": "content here"},
                            "depth": 1,
                        },
                    ]
                },
                json_data_callback=ANY,
            "remote",
            path="/_matrix/federation/v1/send/1000000/",
            data={
                "origin_server_ts": 1000000,
                "origin": "test",
                "pdus": [
                    pdu.get_dict(),
                ],
                'pdu_failures': [],
            },
            json_data_callback=ANY,
        )

    @defer.inlineCallbacks
    def test_send_edu(self):
        self.mock_http_client.put_json.return_value = defer.succeed(
                (200, "OK")
            (200, "OK")
        )

        yield self.federation.send_edu(
                destination="remote",
                edu_type="m.test",
                content={"testing": "content here"},
            destination="remote",
            edu_type="m.test",
            content={"testing": "content here"},
        )

        # MockClock ensures we can guess these timestamps
        self.mock_http_client.put_json.assert_called_with(
                "remote",
                path="/_matrix/federation/v1/send/1000000/",
                data={
                    "origin": "test",
                    "origin_server_ts": 1000000,
                    "pdus": [],
                    "edus": [
                        {
                            # TODO: SYN-103: Remove "origin" and "destination"
                            "origin": "test",
                            "destination": "remote",
                            "edu_type": "m.test",
                            "content": {"testing": "content here"},
                        }
                    ],
                },
                json_data_callback=ANY,
            "remote",
            path="/_matrix/federation/v1/send/1000000/",
            data={
                "origin": "test",
                "origin_server_ts": 1000000,
                "pdus": [],
                "edus": [
                    {
                        # TODO: SYN-103: Remove "origin" and "destination"
                        "origin": "test",
                        "destination": "remote",
                        "edu_type": "m.test",
                        "content": {"testing": "content here"},
                    }
                ],
                'pdu_failures': [],
            },
            json_data_callback=ANY,
        )


    @defer.inlineCallbacks
    def test_recv_edu(self):
        recv_observer = Mock()

@ -230,24 +237,26 @@ class FederationTestCase(unittest.TestCase):

        self.federation.register_edu_handler("m.test", recv_observer)

        yield self.mock_resource.trigger("PUT",
                "/_matrix/federation/v1/send/1001000/",
                """{
                    "origin": "remote",
                    "origin_server_ts": 1001000,
                    "pdus": [],
                    "edus": [
                        {
                            "origin": "remote",
                            "destination": "test",
                            "edu_type": "m.test",
                            "content": {"testing": "reply here"}
                        }
                    ]
                }""")
        yield self.mock_resource.trigger(
            "PUT",
            "/_matrix/federation/v1/send/1001000/",
            """{
                "origin": "remote",
                "origin_server_ts": 1001000,
                "pdus": [],
                "edus": [
                    {
                        "origin": "remote",
                        "destination": "test",
                        "edu_type": "m.test",
                        "content": {"testing": "reply here"}
                    }
                ]
            }"""
        )

        recv_observer.assert_called_with(
                "remote", {"testing": "reply here"}
            "remote", {"testing": "reply here"}
        )

    @defer.inlineCallbacks
@ -278,8 +287,11 @@ class FederationTestCase(unittest.TestCase):

        self.federation.register_query_handler("a-question", recv_handler)

        code, response = yield self.mock_resource.trigger("GET",
                "/_matrix/federation/v1/query/a-question?three=3&four=4", None)
        code, response = yield self.mock_resource.trigger(
            "GET",
            "/_matrix/federation/v1/query/a-question?three=3&four=4",
            None
        )

        self.assertEquals(200, code)
        self.assertEquals({"another": "response"}, response)

@ -1,160 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tests import unittest

from synapse.federation.pdu_codec import (
    PduCodec, encode_event_id, decode_event_id
)
from synapse.federation.units import Pdu
#from synapse.api.events.room import MessageEvent

from synapse.server import HomeServer

from mock import Mock


class PduCodecTestCase(unittest.TestCase):
    def setUp(self):
        self.hs = HomeServer("blargle.net")
        self.event_factory = self.hs.get_event_factory()

        self.codec = PduCodec(self.hs)

    def test_decode_event_id(self):
        self.assertEquals(
            ("foo", "bar.com"),
            decode_event_id("foo@bar.com", "A")
        )

        self.assertEquals(
            ("foo", "bar.com"),
            decode_event_id("foo", "bar.com")
        )

    def test_encode_event_id(self):
        self.assertEquals("A@B", encode_event_id("A", "B"))

    def test_codec_event_id(self):
        event_id = "aa@bb.com"

        self.assertEquals(
            event_id,
            encode_event_id(*decode_event_id(event_id, None))
        )

        pdu_id = ("aa", "bb.com")

        self.assertEquals(
            pdu_id,
            decode_event_id(encode_event_id(*pdu_id), None)
        )

    def test_event_from_pdu(self):
        pdu = Pdu(
            pdu_id="foo",
            context="rooooom",
            pdu_type="m.room.message",
            origin="bar.com",
            origin_server_ts=12345,
            depth=5,
            prev_pdus=[("alice", "bob.com")],
            is_state=False,
            content={"msgtype": u"test"},
        )

        event = self.codec.event_from_pdu(pdu)

        self.assertEquals("foo@bar.com", event.event_id)
        self.assertEquals(pdu.context, event.room_id)
        self.assertEquals(pdu.is_state, event.is_state)
        self.assertEquals(pdu.depth, event.depth)
        self.assertEquals(["alice@bob.com"], event.prev_events)
        self.assertEquals(pdu.content, event.content)

    def test_pdu_from_event(self):
        event = self.event_factory.create_event(
            etype="m.room.message",
            event_id="gargh_id",
            room_id="rooom",
            user_id="sender",
            content={"msgtype": u"test"},
        )

        pdu = self.codec.pdu_from_event(event)

        self.assertEquals(event.event_id, pdu.pdu_id)
        self.assertEquals(self.hs.hostname, pdu.origin)
        self.assertEquals(event.room_id, pdu.context)
        self.assertEquals(event.content, pdu.content)
        self.assertEquals(event.type, pdu.pdu_type)

        event = self.event_factory.create_event(
            etype="m.room.message",
            event_id="gargh_id@bob.com",
            room_id="rooom",
            user_id="sender",
            content={"msgtype": u"test"},
        )

        pdu = self.codec.pdu_from_event(event)

        self.assertEquals("gargh_id", pdu.pdu_id)
        self.assertEquals("bob.com", pdu.origin)
        self.assertEquals(event.room_id, pdu.context)
        self.assertEquals(event.content, pdu.content)
        self.assertEquals(event.type, pdu.pdu_type)

    def test_event_from_state_pdu(self):
        pdu = Pdu(
            pdu_id="foo",
            context="rooooom",
            pdu_type="m.room.topic",
            origin="bar.com",
            origin_server_ts=12345,
            depth=5,
            prev_pdus=[("alice", "bob.com")],
            is_state=True,
            content={"topic": u"test"},
            state_key="",
        )

        event = self.codec.event_from_pdu(pdu)

        self.assertEquals("foo@bar.com", event.event_id)
        self.assertEquals(pdu.context, event.room_id)
        self.assertEquals(pdu.is_state, event.is_state)
        self.assertEquals(pdu.depth, event.depth)
        self.assertEquals(["alice@bob.com"], event.prev_events)
        self.assertEquals(pdu.content, event.content)
        self.assertEquals(pdu.state_key, event.state_key)

    def test_pdu_from_state_event(self):
        event = self.event_factory.create_event(
            etype="m.room.topic",
            event_id="gargh_id",
            room_id="rooom",
            user_id="sender",
            content={"topic": u"test"},
        )

        pdu = self.codec.pdu_from_event(event)

        self.assertEquals(event.event_id, pdu.pdu_id)
        self.assertEquals(self.hs.hostname, pdu.origin)
        self.assertEquals(event.room_id, pdu.context)
        self.assertEquals(event.content, pdu.content)
        self.assertEquals(event.type, pdu.pdu_type)
        self.assertEquals(event.state_key, pdu.state_key)
@ -21,9 +21,8 @@ from mock import Mock

from synapse.server import HomeServer
from synapse.handlers.directory import DirectoryHandler
from synapse.storage.directory import RoomAliasMapping

from tests.utils import SQLiteMemoryDbPool
from tests.utils import SQLiteMemoryDbPool, MockKey


class DirectoryHandlers(object):

@ -41,6 +40,7 @@ class DirectoryTestCase(unittest.TestCase):
        ])

        self.query_handlers = {}

        def register_query_handler(query_type, handler):
            self.query_handlers[query_type] = handler
        self.mock_federation.register_query_handler = register_query_handler

@ -48,11 +48,16 @@ class DirectoryTestCase(unittest.TestCase):
        db_pool = SQLiteMemoryDbPool()
        yield db_pool.prepare()

        hs = HomeServer("test",
        self.mock_config = Mock()
        self.mock_config.signing_key = [MockKey()]

        hs = HomeServer(
            "test",
            db_pool=db_pool,
            http_client=None,
            resource_for_federation=Mock(),
            replication_layer=self.mock_federation,
            config=self.mock_config,
        )
        hs.handlers = DirectoryHandlers(hs)

@ -17,16 +17,15 @@ from twisted.internet import defer
from tests import unittest

from synapse.api.events.room import (
    InviteJoinEvent, MessageEvent, RoomMemberEvent
    MessageEvent,
)
from synapse.api.constants import Membership
from synapse.handlers.federation import FederationHandler
from synapse.server import HomeServer
from synapse.federation.units import Pdu

from mock import NonCallableMock, ANY

from ..utils import get_mock_call_args, MockKey
from ..utils import MockKey


class FederationTestCase(unittest.TestCase):

@ -36,6 +35,14 @@ class FederationTestCase(unittest.TestCase):
        self.mock_config = NonCallableMock()
        self.mock_config.signing_key = [MockKey()]

        self.state_handler = NonCallableMock(spec_set=[
            "annotate_state_groups",
        ])

        self.auth = NonCallableMock(spec_set=[
            "check",
        ])

        self.hostname = "test"
        hs = HomeServer(
            self.hostname,
@ -53,6 +60,8 @@ class FederationTestCase(unittest.TestCase):
                "federation_handler",
            ]),
            config=self.mock_config,
            auth=self.auth,
            state_handler=self.state_handler,
        )

        self.datastore = hs.get_datastore()
@ -65,74 +74,35 @@ class FederationTestCase(unittest.TestCase):
    @defer.inlineCallbacks
    def test_msg(self):
        pdu = Pdu(
            pdu_type=MessageEvent.TYPE,
            context="foo",
            type=MessageEvent.TYPE,
            room_id="foo",
            content={"msgtype": u"fooo"},
            origin_server_ts=0,
            pdu_id="a",
            event_id="$a:b",
            origin="b",
        )

        store_id = "ASD"
        self.datastore.persist_event.return_value = defer.succeed(store_id)
        self.datastore.persist_event.return_value = defer.succeed(None)
        self.datastore.get_room.return_value = defer.succeed(True)

        self.state_handler.annotate_state_groups.return_value = (
            defer.succeed(False)
        )

        yield self.handlers.federation_handler.on_receive_pdu(pdu, False)

        self.datastore.persist_event.assert_called_once_with(
            ANY, False, is_new_state=False
        )
        self.notifier.on_new_room_event.assert_called_once_with(ANY, extra_users=[])

    @defer.inlineCallbacks
    def test_invite_join_target_this(self):
        room_id = "foo"
        user_id = "@bob:red"

        pdu = Pdu(
            pdu_type=InviteJoinEvent.TYPE,
            user_id=user_id,
            target_host=self.hostname,
            context=room_id,
            content={},
            origin_server_ts=0,
            pdu_id="a",
            origin="b",
        self.state_handler.annotate_state_groups.assert_called_once_with(
            ANY,
            old_state=None,
        )

        yield self.handlers.federation_handler.on_receive_pdu(pdu, False)
        self.auth.check.assert_called_once_with(ANY, raises=True)

        mem_handler = self.handlers.room_member_handler
        self.assertEquals(1, mem_handler.change_membership.call_count)
        call_args = get_mock_call_args(
            lambda event, do_auth: None,
            mem_handler.change_membership
        self.notifier.on_new_room_event.assert_called_once_with(
            ANY,
            extra_users=[]
        )
        self.assertEquals(False, call_args["do_auth"])

        new_event = call_args["event"]
        self.assertEquals(RoomMemberEvent.TYPE, new_event.type)
        self.assertEquals(room_id, new_event.room_id)
        self.assertEquals(user_id, new_event.state_key)
        self.assertEquals(Membership.JOIN, new_event.membership)

    @defer.inlineCallbacks
    def test_invite_join_target_other(self):
        room_id = "foo"
        user_id = "@bob:red"

        pdu = Pdu(
            pdu_type=InviteJoinEvent.TYPE,
            user_id=user_id,
            state_key="@red:not%s" % self.hostname,
            context=room_id,
            content={},
            origin_server_ts=0,
            pdu_id="a",
            origin="b",
        )

        yield self.handlers.federation_handler.on_receive_pdu(pdu, False)

        mem_handler = self.handlers.room_member_handler
        self.assertEquals(0, mem_handler.change_membership.call_count)

@ -51,6 +51,7 @@ def _expect_edu(destination, edu_type, content, origin="test"):
                "content": content,
            }
        ],
        "pdu_failures": [],
    }

def _make_edu_json(origin, edu_type, content):

@ -21,7 +21,7 @@ from twisted.internet import defer

from mock import Mock, call, ANY

from ..utils import MockClock
from ..utils import MockClock, MockKey

from synapse.server import HomeServer
from synapse.api.constants import PresenceState

@ -57,6 +57,9 @@ class PresenceAndProfileHandlers(object):
class PresenceProfilelikeDataTestCase(unittest.TestCase):

    def setUp(self):
        self.mock_config = Mock()
        self.mock_config.signing_key = [MockKey()]

        hs = HomeServer("test",
            clock=MockClock(),
            db_pool=None,
@ -72,6 +75,7 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
            resource_for_federation=Mock(),
            http_client=None,
            replication_layer=MockReplication(),
            config=self.mock_config,
        )
        hs.handlers = PresenceAndProfileHandlers(hs)

@ -24,7 +24,7 @@ from synapse.server import HomeServer
from synapse.handlers.profile import ProfileHandler
from synapse.api.constants import Membership

from tests.utils import SQLiteMemoryDbPool
from tests.utils import SQLiteMemoryDbPool, MockKey


class ProfileHandlers(object):

@ -49,12 +49,16 @@ class ProfileTestCase(unittest.TestCase):
        db_pool = SQLiteMemoryDbPool()
        yield db_pool.prepare()

        self.mock_config = Mock()
        self.mock_config.signing_key = [MockKey()]

        hs = HomeServer("test",
            db_pool=db_pool,
            http_client=None,
            handlers=None,
            resource_for_federation=Mock(),
            replication_layer=self.mock_federation,
            config=self.mock_config,
        )
        hs.handlers = ProfileHandlers(hs)

@ -18,7 +18,7 @@ from twisted.internet import defer
|
||||
from tests import unittest
|
||||
|
||||
from synapse.api.events.room import (
|
||||
InviteJoinEvent, RoomMemberEvent, RoomConfigEvent
|
||||
RoomMemberEvent,
|
||||
)
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler
|
||||
@ -34,6 +34,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.mock_config = NonCallableMock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
self.hostname = "red"
|
||||
hs = HomeServer(
|
||||
self.hostname,
|
||||
@ -57,13 +58,16 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||
"profile_handler",
|
||||
"federation_handler",
|
||||
]),
|
||||
auth=NonCallableMock(spec_set=["check"]),
|
||||
state_handler=NonCallableMock(spec_set=["handle_new_event"]),
|
||||
auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
|
||||
state_handler=NonCallableMock(spec_set=[
|
||||
"annotate_state_groups",
|
||||
]),
|
||||
config=self.mock_config,
|
||||
)
|
||||
|
||||
self.federation = NonCallableMock(spec_set=[
|
||||
"handle_new_event",
|
||||
"send_invite",
|
||||
"get_state_for_room",
|
||||
])
|
||||
|
||||
@ -106,7 +110,6 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||
|
||||
joined = ["red", "green"]
|
||||
|
||||
self.state_handler.handle_new_event.return_value = defer.succeed(True)
|
||||
self.datastore.get_joined_hosts_for_room.return_value = (
|
||||
defer.succeed(joined)
|
||||
)
|
||||
@ -114,18 +117,29 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||
store_id = "store_id_fooo"
|
||||
self.datastore.persist_event.return_value = defer.succeed(store_id)
|
||||
|
||||
self.datastore.get_room_member.return_value = defer.succeed(None)
|
||||
|
||||
event.state_events = {
|
||||
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
|
||||
user_id="@alice:green",
|
||||
room_id=room_id,
|
||||
),
|
||||
(RoomMemberEvent.TYPE, "@bob:red"): self._create_member(
|
||||
user_id="@bob:red",
|
||||
room_id=room_id,
|
||||
),
|
||||
(RoomMemberEvent.TYPE, target_user_id): event,
|
||||
}
|
||||
|
||||
# Actual invocation
|
||||
yield self.room_member_handler.change_membership(event)
|
||||
|
||||
self.state_handler.handle_new_event.assert_called_once_with(
|
||||
event, self.snapshot,
|
||||
)
|
||||
self.federation.handle_new_event.assert_called_once_with(
|
||||
event, self.snapshot,
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
set(["blue", "red", "green"]),
|
||||
set(["red", "green"]),
|
||||
set(event.destinations)
|
||||
)
|
||||
|
||||
@ -144,28 +158,19 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||
room_id = "!foo:red"
|
||||
user_id = "@bob:red"
|
||||
user = self.hs.parse_userid(user_id)
|
||||
target_user_id = "@bob:red"
|
||||
content = {"membership": Membership.JOIN}
|
||||
|
||||
event = self.hs.get_event_factory().create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
event = self._create_member(
|
||||
user_id=user_id,
|
||||
state_key=target_user_id,
|
||||
room_id=room_id,
|
||||
membership=Membership.JOIN,
|
||||
content=content,
|
||||
)
|
||||
|
||||
joined = ["red", "green"]
|
||||
|
||||
self.state_handler.handle_new_event.return_value = defer.succeed(True)
|
||||
|
||||
def get_joined(*args):
|
||||
return defer.succeed(joined)
|
||||
|
||||
self.datastore.get_joined_hosts_for_room.side_effect = get_joined
|
||||
|
||||
|
||||
store_id = "store_id_fooo"
|
||||
self.datastore.persist_event.return_value = defer.succeed(store_id)
|
||||
self.datastore.get_room.return_value = defer.succeed(1) # Not None.
|
||||
@ -178,12 +183,17 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||
join_signal_observer = Mock()
|
||||
self.distributor.observe("user_joined_room", join_signal_observer)
|
||||
|
||||
event.state_events = {
|
||||
(RoomMemberEvent.TYPE, "@alice:green"): self._create_member(
|
||||
user_id="@alice:green",
|
||||
room_id=room_id,
|
||||
),
|
||||
(RoomMemberEvent.TYPE, user_id): event,
|
||||
}
|
||||
|
||||
# Actual invocation
|
||||
yield self.room_member_handler.change_membership(event)
|
||||
|
||||
self.state_handler.handle_new_event.assert_called_once_with(
|
||||
event, self.snapshot
|
||||
)
|
||||
self.federation.handle_new_event.assert_called_once_with(
|
||||
event, self.snapshot
|
||||
)
|
||||
@ -197,138 +207,32 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||
event
|
||||
)
|
||||
self.notifier.on_new_room_event.assert_called_once_with(
|
||||
event, extra_users=[user])
|
||||
event, extra_users=[user]
|
||||
)
|
||||
|
||||
join_signal_observer.assert_called_with(
|
||||
user=user, room_id=room_id)
|
||||
user=user, room_id=room_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def STALE_test_invite_join(self):
|
||||
room_id = "foo"
|
||||
user_id = "@bob:red"
|
||||
target_user_id = "@bob:red"
|
||||
content = {"membership": Membership.JOIN}
|
||||
|
||||
event = self.hs.get_event_factory().create_event(
|
||||
def _create_member(self, user_id, room_id):
|
||||
return self.hs.get_event_factory().create_event(
|
||||
etype=RoomMemberEvent.TYPE,
|
||||
user_id=user_id,
|
||||
target_user_id=target_user_id,
|
||||
state_key=user_id,
|
||||
room_id=room_id,
|
||||
membership=Membership.JOIN,
|
||||
content=content,
|
||||
content={"membership": Membership.JOIN},
|
||||
)
|
||||
|
||||
joined = ["red", "blue", "green"]
|
||||
|
||||
self.state_handler.handle_new_event.return_value = defer.succeed(True)
|
||||
self.datastore.get_joined_hosts_for_room.return_value = (
|
||||
defer.succeed(joined)
|
||||
)
|
||||
|
||||
store_id = "store_id_fooo"
|
||||
self.datastore.store_room_member.return_value = defer.succeed(store_id)
|
self.datastore.get_room.return_value = defer.succeed(None)

prev_state = NonCallableMock(name="prev_state")
prev_state.membership = Membership.INVITE
prev_state.sender = "@foo:blue"
self.datastore.get_room_member.return_value = defer.succeed(prev_state)

# Actual invocation
yield self.room_member_handler.change_membership(event)

self.datastore.get_room_member.assert_called_once_with(
target_user_id, room_id
)

self.assertTrue(self.federation.handle_new_event.called)
args = self.federation.handle_new_event.call_args[0]
invite_join_event = args[0]

self.assertTrue(InviteJoinEvent.TYPE, invite_join_event.TYPE)
self.assertTrue("blue", invite_join_event.target_host)
self.assertTrue(room_id, invite_join_event.room_id)
self.assertTrue(user_id, invite_join_event.user_id)
self.assertFalse(hasattr(invite_join_event, "state_key"))

self.assertEquals(
set(["blue"]),
set(invite_join_event.destinations)
)

self.federation.get_state_for_room.assert_called_once_with(
"blue", room_id
)

self.assertFalse(self.datastore.store_room_member.called)

self.assertFalse(self.notifier.on_new_room_event.called)
self.assertFalse(self.state_handler.handle_new_event.called)

@defer.inlineCallbacks
def STALE_test_invite_join_public(self):
room_id = "#foo:blue"
user_id = "@bob:red"
target_user_id = "@bob:red"
content = {"membership": Membership.JOIN}

event = self.hs.get_event_factory().create_event(
etype=RoomMemberEvent.TYPE,
user_id=user_id,
target_user_id=target_user_id,
room_id=room_id,
membership=Membership.JOIN,
content=content,
)

joined = ["red", "blue", "green"]

self.state_handler.handle_new_event.return_value = defer.succeed(True)
self.datastore.get_joined_hosts_for_room.return_value = (
defer.succeed(joined)
)

store_id = "store_id_fooo"
self.datastore.store_room_member.return_value = defer.succeed(store_id)
self.datastore.get_room.return_value = defer.succeed(None)

prev_state = NonCallableMock(name="prev_state")
prev_state.membership = Membership.INVITE
prev_state.sender = "@foo:blue"
self.datastore.get_room_member.return_value = defer.succeed(prev_state)

# Actual invocation
yield self.room_member_handler.change_membership(event)

self.assertTrue(self.federation.handle_new_event.called)
args = self.federation.handle_new_event.call_args[0]
invite_join_event = args[0]

self.assertTrue(InviteJoinEvent.TYPE, invite_join_event.TYPE)
self.assertTrue("blue", invite_join_event.target_host)
self.assertTrue("foo", invite_join_event.room_id)
self.assertTrue(user_id, invite_join_event.user_id)
self.assertFalse(hasattr(invite_join_event, "state_key"))

self.assertEquals(
set(["blue"]),
set(invite_join_event.destinations)
)

self.federation.get_state_for_room.assert_called_once_with(
"blue", "foo"
)

self.assertFalse(self.datastore.store_room_member.called)

self.assertFalse(self.notifier.on_new_room_event.called)
self.assertFalse(self.state_handler.handle_new_event.called)


class RoomCreationTest(unittest.TestCase):

def setUp(self):
self.hostname = "red"

self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]

hs = HomeServer(
self.hostname,
db_pool=None,
@ -345,12 +249,14 @@ class RoomCreationTest(unittest.TestCase):
"room_member_handler",
"federation_handler",
]),
auth=NonCallableMock(spec_set=["check"]),
state_handler=NonCallableMock(spec_set=["handle_new_event"]),
auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
state_handler=NonCallableMock(spec_set=[
"annotate_state_groups",
]),
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
config=self.mock_config,
)

self.federation = NonCallableMock(spec_set=[
@ -373,6 +279,11 @@ class RoomCreationTest(unittest.TestCase):
])
self.room_member_handler = self.handlers.room_member_handler

def annotate(event):
event.state_events = {}
return defer.succeed(None)
self.state_handler.annotate_state_groups.side_effect = annotate

def hosts(room):
return defer.succeed([])
self.datastore.get_joined_hosts_for_room.side_effect = hosts
@ -400,6 +311,6 @@ class RoomCreationTest(unittest.TestCase):
self.assertEquals(user_id, join_event.user_id)
self.assertEquals(user_id, join_event.state_key)

self.assertTrue(self.state_handler.handle_new_event.called)
self.assertTrue(self.state_handler.annotate_state_groups.called)

self.assertTrue(self.federation.handle_new_event.called)

@ -40,6 +40,7 @@ def _expect_edu(destination, edu_type, content, origin="test"):
"content": content,
}
],
"pdu_failures": [],
}


@ -25,10 +25,7 @@ import synapse.rest.room

from synapse.server import HomeServer

# python imports
import json

from ..utils import MockHttpResource, MemoryDataStore
from ..utils import MockHttpResource, SQLiteMemoryDbPool, MockKey
from .utils import RestTestCase

from mock import Mock, NonCallableMock
@ -49,7 +46,7 @@ class EventStreamPaginationApiTestCase(unittest.TestCase):
def tearDown(self):
pass

def test_long_poll(self):
def TODO_test_long_poll(self):
# stream from 'end' key, send (self+other) message, expect message.

# stream from 'END', send (self+other) message, expect message.
@ -64,7 +61,7 @@ class EventStreamPaginationApiTestCase(unittest.TestCase):

pass

def test_stream_forward(self):
def TODO_test_stream_forward(self):
# stream from START, expect injected items

# stream from 'start' key, expect same content
@ -80,14 +77,14 @@ class EventStreamPaginationApiTestCase(unittest.TestCase):
# returned as end key
pass

def test_limits(self):
def TODO_test_limits(self):
# stream from a key, expect limit_num items

# stream from START, expect limit_num items

pass

def test_range(self):
def TODO_test_range(self):
# stream from key to key, expect X items

# stream from key to END, expect X items
@ -97,7 +94,7 @@ class EventStreamPaginationApiTestCase(unittest.TestCase):
# stream from START to END, expect all items
pass

def test_direction(self):
def TODO_test_direction(self):
# stream from END to START and fwds, expect newest first

# stream from END to START and bwds, expect oldest first
@ -116,19 +113,20 @@ class EventStreamPermissionsTestCase(RestTestCase):
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)

state_handler = Mock(spec=["handle_new_event"])
state_handler.handle_new_event.return_value = True

persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = []

self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]

db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()

hs = HomeServer(
"test",
db_pool=None,
db_pool=db_pool,
http_client=None,
replication_layer=Mock(),
state_handler=state_handler,
datastore=MemoryDataStore(),
persistence_service=persistence_service,
clock=Mock(spec=[
"call_later",
@ -139,7 +137,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
config=self.mock_config,
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@ -148,6 +146,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()

hs.get_clock().time_msec.return_value = 1000000
hs.get_clock().time.return_value = 1000

synapse.rest.register.register_servlets(hs, self.mock_resource)
synapse.rest.events.register_servlets(hs, self.mock_resource)
@ -172,12 +171,14 @@ class EventStreamPermissionsTestCase(RestTestCase):
def test_stream_basic_permissions(self):
# invalid token, expect 403
(code, response) = yield self.mock_resource.trigger_get(
"/events?access_token=%s" % ("invalid" + self.token))
"/events?access_token=%s" % ("invalid" + self.token, )
)
self.assertEquals(403, code, msg=str(response))

# valid token, expect content
(code, response) = yield self.mock_resource.trigger_get(
"/events?access_token=%s&timeout=0" % (self.token))
"/events?access_token=%s&timeout=0" % (self.token,)
)
self.assertEquals(200, code, msg=str(response))
self.assertTrue("chunk" in response)
self.assertTrue("start" in response)
@ -185,15 +186,23 @@ class EventStreamPermissionsTestCase(RestTestCase):

@defer.inlineCallbacks
def test_stream_room_permissions(self):
room_id = yield self.create_room_as(self.other_user,
tok=self.other_token)
room_id = yield self.create_room_as(
self.other_user,
tok=self.other_token
)
yield self.send(room_id, tok=self.other_token)

# invited to room (expect no content for room)
yield self.invite(room_id, src=self.other_user, targ=self.user_id,
tok=self.other_token)
yield self.invite(
room_id,
src=self.other_user,
targ=self.user_id,
tok=self.other_token
)

(code, response) = yield self.mock_resource.trigger_get(
"/events?access_token=%s&timeout=0" % (self.token))
"/events?access_token=%s&timeout=0" % (self.token,)
)
self.assertEquals(200, code, msg=str(response))

self.assertEquals(0, len(response["chunk"]))
@ -203,7 +212,7 @@ class EventStreamPermissionsTestCase(RestTestCase):

# left to room (expect no content for room)

def test_stream_items(self):
def TODO_test_stream_items(self):
# new user, no content

# join room, expect 1 item (join)

@ -18,9 +18,9 @@
from tests import unittest
from twisted.internet import defer

from mock import Mock
from mock import Mock, NonCallableMock

from ..utils import MockHttpResource
from ..utils import MockHttpResource, MockKey

from synapse.api.errors import SynapseError, AuthError
from synapse.server import HomeServer
@ -41,6 +41,9 @@ class ProfileTestCase(unittest.TestCase):
"set_avatar_url",
])

self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]

hs = HomeServer("test",
db_pool=None,
http_client=None,
@ -48,6 +51,7 @@ class ProfileTestCase(unittest.TestCase):
federation=Mock(),
replication_layer=Mock(),
datastore=None,
config=self.mock_config,
)

def _get_user_by_req(request=None):

@ -23,11 +23,14 @@ from synapse.api.constants import Membership

from synapse.server import HomeServer

from tests import unittest

# python imports
import json
import urllib
import types

from ..utils import MockHttpResource, MemoryDataStore
from ..utils import MockHttpResource, SQLiteMemoryDbPool, MockKey
from .utils import RestTestCase

from mock import Mock, NonCallableMock
@ -44,24 +47,21 @@ class RoomPermissionsTestCase(RestTestCase):
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)

state_handler = Mock(spec=["handle_new_event"])
state_handler.handle_new_event.return_value = True
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]

persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = []
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()

hs = HomeServer(
"red",
db_pool=None,
db_pool=db_pool,
http_client=None,
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
config=self.mock_config,
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@ -76,6 +76,10 @@ class RoomPermissionsTestCase(RestTestCase):
}
hs.get_auth().get_user_by_token = _get_user_by_token

def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
hs.get_datastore().insert_client_ip = _insert_client_ip

self.auth_user_id = self.rmcreator_id

synapse.rest.room.register_servlets(hs, self.mock_resource)
@ -147,38 +151,55 @@ class RoomPermissionsTestCase(RestTestCase):
@defer.inlineCallbacks
def test_send_message(self):
msg_content = '{"msgtype":"m.text","body":"hello"}'
send_msg_path = ("/rooms/%s/send/m.room.message/mid1" %
(self.created_rmid))
send_msg_path = (
"/rooms/%s/send/m.room.message/mid1" % (self.created_rmid,)
)

# send message in uncreated room, expect 403
(code, response) = yield self.mock_resource.trigger(
"PUT",
"/rooms/%s/send/m.room.message/mid2" %
(self.uncreated_rmid), msg_content)
"PUT",
"/rooms/%s/send/m.room.message/mid2" % (self.uncreated_rmid,),
msg_content
)
self.assertEquals(403, code, msg=str(response))

# send message in created room not joined (no state), expect 403
(code, response) = yield self.mock_resource.trigger(
"PUT", send_msg_path, msg_content)
"PUT",
send_msg_path,
msg_content
)
self.assertEquals(403, code, msg=str(response))

# send message in created room and invited, expect 403
yield self.invite(room=self.created_rmid, src=self.rmcreator_id,
targ=self.user_id)
yield self.invite(
room=self.created_rmid,
src=self.rmcreator_id,
targ=self.user_id
)
(code, response) = yield self.mock_resource.trigger(
"PUT", send_msg_path, msg_content)
"PUT",
send_msg_path,
msg_content
)
self.assertEquals(403, code, msg=str(response))

# send message in created room and joined, expect 200
yield self.join(room=self.created_rmid, user=self.user_id)
(code, response) = yield self.mock_resource.trigger(
"PUT", send_msg_path, msg_content)
"PUT",
send_msg_path,
msg_content
)
self.assertEquals(200, code, msg=str(response))

# send message in created room and left, expect 403
yield self.leave(room=self.created_rmid, user=self.user_id)
(code, response) = yield self.mock_resource.trigger(
"PUT", send_msg_path, msg_content)
"PUT",
send_msg_path,
msg_content
)
self.assertEquals(403, code, msg=str(response))

@defer.inlineCallbacks
@ -215,9 +236,14 @@ class RoomPermissionsTestCase(RestTestCase):

# set/get topic in created PRIVATE room and joined, expect 200
yield self.join(room=self.created_rmid, user=self.user_id)

# Only room ops can set topic by default
self.auth_user_id = self.rmcreator_id
(code, response) = yield self.mock_resource.trigger(
"PUT", topic_path, topic_content)
self.assertEquals(200, code, msg=str(response))
self.auth_user_id = self.user_id

(code, response) = yield self.mock_resource.trigger_get(topic_path)
self.assertEquals(200, code, msg=str(response))
self.assert_dict(json.loads(topic_content), response)
@ -381,45 +407,55 @@ class RoomPermissionsTestCase(RestTestCase):
# set [invite/join/left] of self, set [invite/join/left] of other,
# expect all 403s
for usr in [self.user_id, self.rmcreator_id]:
yield self.change_membership(room=room, src=self.user_id,
targ=usr,
membership=Membership.INVITE,
expect_code=403)
yield self.change_membership(room=room, src=self.user_id,
targ=usr,
membership=Membership.JOIN,
expect_code=403)
yield self.change_membership(room=room, src=self.user_id,
targ=usr,
membership=Membership.LEAVE,
expect_code=403)
yield self.change_membership(
room=room,
src=self.user_id,
targ=usr,
membership=Membership.INVITE,
expect_code=403
)

yield self.change_membership(
room=room,
src=self.user_id,
targ=usr,
membership=Membership.JOIN,
expect_code=403
)

# It is always valid to LEAVE if you've already left (currently.)
yield self.change_membership(
room=room,
src=self.user_id,
targ=self.rmcreator_id,
membership=Membership.LEAVE,
expect_code=403
)


class RoomsMemberListTestCase(RestTestCase):
""" Tests /rooms/$room_id/members/list REST events."""
user_id = "@sid1:red"

@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)

state_handler = Mock(spec=["handle_new_event"])
state_handler.handle_new_event.return_value = True
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]

persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = []
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()

hs = HomeServer(
"red",
db_pool=None,
db_pool=db_pool,
http_client=None,
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
config=self.mock_config,
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@ -436,6 +472,10 @@ class RoomsMemberListTestCase(RestTestCase):
}
hs.get_auth().get_user_by_token = _get_user_by_token

def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
hs.get_datastore().insert_client_ip = _insert_client_ip

synapse.rest.room.register_servlets(hs, self.mock_resource)

def tearDown(self):
@ -487,28 +527,26 @@ class RoomsCreateTestCase(RestTestCase):
""" Tests /rooms and /rooms/$room_id REST events. """
user_id = "@sid1:red"

@defer.inlineCallbacks
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.auth_user_id = self.user_id

state_handler = Mock(spec=["handle_new_event"])
state_handler.handle_new_event.return_value = True
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]

persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = []
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()

hs = HomeServer(
"red",
db_pool=None,
db_pool=db_pool,
http_client=None,
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
config=self.mock_config,
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@ -523,6 +561,10 @@ class RoomsCreateTestCase(RestTestCase):
}
hs.get_auth().get_user_by_token = _get_user_by_token

def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
hs.get_datastore().insert_client_ip = _insert_client_ip

synapse.rest.room.register_servlets(hs, self.mock_resource)

def tearDown(self):
@ -592,24 +634,21 @@ class RoomTopicTestCase(RestTestCase):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.auth_user_id = self.user_id

state_handler = Mock(spec=["handle_new_event"])
state_handler.handle_new_event.return_value = True
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]

persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = []
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()

hs = HomeServer(
"red",
db_pool=None,
db_pool=db_pool,
http_client=None,
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
config=self.mock_config,
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@ -622,13 +661,18 @@ class RoomTopicTestCase(RestTestCase):
"admin": False,
"device_id": None,
}

hs.get_auth().get_user_by_token = _get_user_by_token

def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
hs.get_datastore().insert_client_ip = _insert_client_ip

synapse.rest.room.register_servlets(hs, self.mock_resource)

# create the room
self.room_id = yield self.create_room_as(self.user_id)
self.path = "/rooms/%s/state/m.room.topic" % self.room_id
self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,)

def tearDown(self):
pass
@ -706,24 +750,21 @@ class RoomMemberStateTestCase(RestTestCase):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.auth_user_id = self.user_id

state_handler = Mock(spec=["handle_new_event"])
state_handler.handle_new_event.return_value = True
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]

persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = []
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()

hs = HomeServer(
"red",
db_pool=None,
db_pool=db_pool,
http_client=None,
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
config=self.mock_config,
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@ -736,13 +777,12 @@ class RoomMemberStateTestCase(RestTestCase):
"admin": False,
"device_id": None,
}
return {
"user": hs.parse_userid(self.auth_user_id),
"admin": False,
"device_id": None,
}
hs.get_auth().get_user_by_token = _get_user_by_token

def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
hs.get_datastore().insert_client_ip = _insert_client_ip

synapse.rest.room.register_servlets(hs, self.mock_resource)

self.room_id = yield self.create_room_as(self.user_id)
@ -847,24 +887,21 @@ class RoomMessagesTestCase(RestTestCase):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
self.auth_user_id = self.user_id

state_handler = Mock(spec=["handle_new_event"])
state_handler.handle_new_event.return_value = True
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]

persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = []
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()

hs = HomeServer(
"red",
db_pool=None,
db_pool=db_pool,
http_client=None,
datastore=MemoryDataStore(),
replication_layer=Mock(),
state_handler=state_handler,
persistence_service=persistence_service,
ratelimiter=NonCallableMock(spec_set=[
"send_message",
]),
config=NonCallableMock(),
config=self.mock_config,
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0)
@ -879,6 +916,10 @@ class RoomMessagesTestCase(RestTestCase):
}
hs.get_auth().get_user_by_token = _get_user_by_token

def _insert_client_ip(*args, **kwargs):
return defer.succeed(None)
hs.get_datastore().insert_client_ip = _insert_client_ip

synapse.rest.room.register_servlets(hs, self.mock_resource)

self.room_id = yield self.create_room_as(self.user_id)

@ -74,7 +74,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_select_one_1col(self):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = ("Value",)
self.mock_txn.fetchall.return_value = [("Value",)]

value = yield self.datastore._simple_select_one_onecol(
table="tablename",

@ -61,6 +61,7 @@ class RedactionTestCase(unittest.TestCase):
membership=membership,
content={"membership": membership},
depth=self.depth,
prev_events=[],
)

event.content.update(extra_content)
@ -68,6 +69,11 @@ class RedactionTestCase(unittest.TestCase):
if prev_state:
event.prev_state = prev_state

event.state_events = None
event.hashes = {}
event.prev_state = []
event.auth_events = []

# Have to create a join event using the eventfactory
yield self.store.persist_event(
event
@ -85,8 +91,13 @@ class RedactionTestCase(unittest.TestCase):
room_id=room.to_string(),
content={"body": body, "msgtype": u"message"},
depth=self.depth,
prev_events=[],
)

event.state_events = None
event.hashes = {}
event.auth_events = []

yield self.store.persist_event(
event
)
@ -102,8 +113,13 @@ class RedactionTestCase(unittest.TestCase):
content={"reason": reason},
depth=self.depth,
redacts=event_id,
prev_events=[],
)

event.state_events = None
event.hashes = {}
event.auth_events = []

yield self.store.persist_event(
event
)

@ -127,7 +127,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
)

@defer.inlineCallbacks
def test_room_name(self):
def STALE_test_room_name(self):
name = u"A-Room-Name"

yield self.inject_room_event(
@ -150,7 +150,7 @@ class RoomEventsStoreTestCase(unittest.TestCase):
)

@defer.inlineCallbacks
def test_room_name(self):
def STALE_test_room_topic(self):
topic = u"A place for things"

yield self.inject_room_event(

@ -51,16 +51,24 @@ class RoomMemberStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def inject_room_member(self, room, user, membership):
# Have to create a join event using the eventfactory
event = self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
user_id=user.to_string(),
state_key=user.to_string(),
room_id=room.to_string(),
membership=membership,
content={"membership": membership},
depth=1,
prev_events=[],
)

event.state_events = None
event.hashes = {}
event.prev_state = {}
event.auth_events = {}

yield self.store.persist_event(
self.event_factory.create_event(
etype=RoomMemberEvent.TYPE,
user_id=user.to_string(),
state_key=user.to_string(),
room_id=room.to_string(),
membership=membership,
content={"membership": membership},
depth=1,
)
event
)

@defer.inlineCallbacks

@ -48,7 +48,7 @@ class StreamStoreTestCase(unittest.TestCase):
self.depth = 1

@defer.inlineCallbacks
def inject_room_member(self, room, user, membership, prev_state=None):
def inject_room_member(self, room, user, membership, replaces_state=None):
self.depth += 1

event = self.event_factory.create_event(
@ -59,10 +59,17 @@ class StreamStoreTestCase(unittest.TestCase):
membership=membership,
content={"membership": membership},
depth=self.depth,
prev_events=[],
)

if prev_state:
event.prev_state = prev_state
event.state_events = None
event.hashes = {}
event.prev_state = []
event.auth_events = []

if replaces_state:
event.prev_state = [(replaces_state, "hash")]
event.replaces_state = replaces_state

# Have to create a join event using the eventfactory
yield self.store.persist_event(
@ -75,15 +82,22 @@ class StreamStoreTestCase(unittest.TestCase):
def inject_message(self, room, user, body):
self.depth += 1

event = self.event_factory.create_event(
etype=MessageEvent.TYPE,
user_id=user.to_string(),
room_id=room.to_string(),
content={"body": body, "msgtype": u"message"},
depth=self.depth,
prev_events=[],
)

event.state_events = None
event.hashes = {}
event.auth_events = []

# Have to create a join event using the eventfactory
yield self.store.persist_event(
self.event_factory.create_event(
etype=MessageEvent.TYPE,
user_id=user.to_string(),
room_id=room.to_string(),
content={"body": body, "msgtype": u"message"},
depth=self.depth,
)
event
)

@defer.inlineCallbacks
@ -206,7 +220,7 @@ class StreamStoreTestCase(unittest.TestCase):

event2 = yield self.inject_room_member(
self.room1, self.u_alice, Membership.JOIN,
prev_state=event1.event_id,
replaces_state=event1.event_id,
)

end = yield self.store.get_room_events_max_id()
@ -223,4 +237,7 @@ class StreamStoreTestCase(unittest.TestCase):

event = results[0]

self.assertTrue(hasattr(event, "prev_content"), msg="No prev_content key")
self.assertTrue(
hasattr(event, "prev_content"),
msg="No prev_content key"
)

@ -15,599 +15,258 @@

from tests import unittest
from twisted.internet import defer
from twisted.python.log import PythonLoggingObserver

from synapse.state import StateHandler
from synapse.storage.pdu import PduEntry
from synapse.federation.pdu_codec import encode_event_id
from synapse.federation.units import Pdu

from collections import namedtuple

from mock import Mock

import mock


ReturnType = namedtuple(
"StateReturnType", ["new_branch", "current_branch"]
)


def _gen_get_power_level(power_level_list):
def get_power_level(room_id, user_id):
return defer.succeed(power_level_list.get(user_id, None))
return get_power_level

class StateTestCase(unittest.TestCase):
def setUp(self):
self.persistence = Mock(spec=[
"get_unresolved_state_tree",
"update_current_state",
"get_latest_pdus_in_context",
"get_current_state_pdu",
"get_pdu",
"get_power_level",
])
self.replication = Mock(spec=["get_pdu"])

hs = Mock(spec=["get_datastore", "get_replication_layer"])
hs.get_datastore.return_value = self.persistence
hs.get_replication_layer.return_value = self.replication
hs.hostname = "bob.com"
self.store = Mock(
spec_set=[
"get_state_groups",
]
)
hs = Mock(spec=["get_datastore"])
hs.get_datastore.return_value = self.store

self.state = StateHandler(hs)
self.event_id = 0

@defer.inlineCallbacks
def test_new_state_key(self):
# We've never seen anything for this state before
new_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u")
def test_annotate_with_old_message(self):
event = self.create_event(type="test_message", name="event")

self.persistence.get_power_level.side_effect = _gen_get_power_level({})

self.persistence.get_unresolved_state_tree.return_value = (
(ReturnType([new_pdu], []), None)
)

is_new = yield self.state.handle_new_state(new_pdu)

self.assertTrue(is_new)

self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)

self.assertEqual(1, self.persistence.update_current_state.call_count)

self.assertFalse(self.replication.get_pdu.called)

@defer.inlineCallbacks
def test_direct_overwrite(self):
# We do a direct overwriting of the old state, i.e., the new state
# points to the old state.

old_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1")
new_pdu = new_fake_pdu("B", "test", "mem", "x", "A", "u2")

self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 5,
})

self.persistence.get_unresolved_state_tree.return_value = (
(ReturnType([new_pdu, old_pdu], [old_pdu]), None)
)

is_new = yield self.state.handle_new_state(new_pdu)

self.assertTrue(is_new)

self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)

self.assertEqual(1, self.persistence.update_current_state.call_count)

self.assertFalse(self.replication.get_pdu.called)

@defer.inlineCallbacks
def test_overwrite(self):
old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", "A", "u2")
new_pdu = new_fake_pdu("C", "test", "mem", "x", "B", "u3")

self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 5,
"u3": 0,
})

self.persistence.get_unresolved_state_tree.return_value = (
(ReturnType([new_pdu, old_pdu_2, old_pdu_1], [old_pdu_1]), None)
)

is_new = yield self.state.handle_new_state(new_pdu)

self.assertTrue(is_new)

self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)

self.assertEqual(1, self.persistence.update_current_state.call_count)

self.assertFalse(self.replication.get_pdu.called)

@defer.inlineCallbacks
def test_power_level_fail(self):
# We try to update the state based on an outdated state, and have a
# too low power level.

old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")

self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 10,
"u3": 5,
})

self.persistence.get_unresolved_state_tree.return_value = (
(ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None)
)

is_new = yield self.state.handle_new_state(new_pdu)

self.assertFalse(is_new)

self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)

self.assertEqual(0, self.persistence.update_current_state.call_count)

self.assertFalse(self.replication.get_pdu.called)

@defer.inlineCallbacks
def test_power_level_succeed(self):
# We try to update the state based on an outdated state, but have
# sufficient power level to force the update.

old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")

self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 10,
"u3": 15,
})

self.persistence.get_unresolved_state_tree.return_value = (
(ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None)
)

is_new = yield self.state.handle_new_state(new_pdu)

self.assertTrue(is_new)

self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)

self.assertEqual(1, self.persistence.update_current_state.call_count)

self.assertFalse(self.replication.get_pdu.called)

@defer.inlineCallbacks
def test_power_level_equal_same_len(self):
# We try to update the state based on an outdated state, the power
# levels are the same and so are the branch lengths

old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
new_pdu = new_fake_pdu("C", "test", "mem", "x", "A", "u3")

self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 10,
"u3": 10,
})

self.persistence.get_unresolved_state_tree.return_value = (
(ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None)
)

is_new = yield self.state.handle_new_state(new_pdu)

self.assertTrue(is_new)

self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)

self.assertEqual(1, self.persistence.update_current_state.call_count)

self.assertFalse(self.replication.get_pdu.called)

@defer.inlineCallbacks
def test_power_level_equal_diff_len(self):
# We try to update the state based on an outdated state, the power
# levels are the same but the branch length of the new one is longer.

old_pdu_1 = new_fake_pdu("A", "test", "mem", "x", None, "u1")
old_pdu_2 = new_fake_pdu("B", "test", "mem", "x", None, "u2")
old_pdu_3 = new_fake_pdu("C", "test", "mem", "x", "A", "u3")
new_pdu = new_fake_pdu("D", "test", "mem", "x", "C", "u4")

self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 10,
"u3": 10,
"u4": 10,
})

self.persistence.get_unresolved_state_tree.return_value = (
(
ReturnType(
[new_pdu, old_pdu_3, old_pdu_1],
[old_pdu_2, old_pdu_1]
),
None
)
)

is_new = yield self.state.handle_new_state(new_pdu)

self.assertTrue(is_new)

self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)

self.assertEqual(1, self.persistence.update_current_state.call_count)

self.assertFalse(self.replication.get_pdu.called)

@defer.inlineCallbacks
def test_missing_pdu(self):
# We try to update state against a PDU we haven't yet seen,
# triggering a get_pdu request

# The pdu we haven't seen
old_pdu_1 = new_fake_pdu(
"A", "test", "mem", "x", None, "u1", depth=0
)

old_pdu_2 = new_fake_pdu(
"B", "test", "mem", "x", "A", "u2", depth=1
)
new_pdu = new_fake_pdu(
"C", "test", "mem", "x", "A", "u3", depth=2
)

self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 10,
"u3": 20,
})

# The return_value of `get_unresolved_state_tree`, which changes after
# the call to get_pdu
tree_to_return = [(ReturnType([new_pdu], [old_pdu_2]), 0)]

def return_tree(p):
return tree_to_return[0]

def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
tree_to_return[0] = (
ReturnType(
[new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]
),
None
)
return defer.succeed(None)

self.persistence.get_unresolved_state_tree.side_effect = return_tree

self.replication.get_pdu.side_effect = set_return_tree

self.persistence.get_pdu.return_value = None

is_new = yield self.state.handle_new_state(new_pdu)

self.assertTrue(is_new)

self.replication.get_pdu.assert_called_with(
destination=new_pdu.origin,
pdu_origin=old_pdu_1.origin,
pdu_id=old_pdu_1.pdu_id,
outlier=True
)

self.persistence.get_unresolved_state_tree.assert_called_with(
new_pdu
)

self.assertEquals(
2, self.persistence.get_unresolved_state_tree.call_count
)

self.assertEqual(1, self.persistence.update_current_state.call_count)

@defer.inlineCallbacks
def test_missing_pdu_depth_1(self):
# We try to update state against a PDU we haven't yet seen,
# triggering a get_pdu request

# The pdu we haven't seen
old_pdu_1 = new_fake_pdu(
"A", "test", "mem", "x", None, "u1", depth=0
)

old_pdu_2 = new_fake_pdu(
"B", "test", "mem", "x", "A", "u2", depth=2
)
old_pdu_3 = new_fake_pdu(
"C", "test", "mem", "x", "B", "u3", depth=3
)
new_pdu = new_fake_pdu(
"D", "test", "mem", "x", "A", "u4", depth=4
)

self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 10,
"u3": 10,
"u4": 20,
})

# The return_value of `get_unresolved_state_tree`, which changes after
# the call to get_pdu
tree_to_return = [
(
ReturnType([new_pdu], [old_pdu_3]),
0
),
(
ReturnType(
[new_pdu, old_pdu_1], [old_pdu_3]
),
1
),
(
ReturnType(
[new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1]
),
None
),
old_state = [
self.create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""),
]

to_return = [0]
yield self.state.annotate_state_groups(event, old_state=old_state)

def return_tree(p):
return tree_to_return[to_return[0]]
for k, v in event.old_state_events.items():
type, state_key = k
self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key)

def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
to_return[0] += 1
return defer.succeed(None)
self.assertEqual(set(old_state), set(event.old_state_events.values()))
self.assertDictEqual(event.old_state_events, event.state_events)

self.persistence.get_unresolved_state_tree.side_effect = return_tree

self.replication.get_pdu.side_effect = set_return_tree

self.persistence.get_pdu.return_value = None

is_new = yield self.state.handle_new_state(new_pdu)

self.assertTrue(is_new)

self.assertEqual(2, self.replication.get_pdu.call_count)

self.replication.get_pdu.assert_has_calls(
[
mock.call(
destination=new_pdu.origin,
pdu_origin=old_pdu_1.origin,
pdu_id=old_pdu_1.pdu_id,
outlier=True
),
mock.call(
destination=old_pdu_3.origin,
pdu_origin=old_pdu_2.origin,
pdu_id=old_pdu_2.pdu_id,
outlier=True
),
]
)

self.persistence.get_unresolved_state_tree.assert_called_with(
new_pdu
)

self.assertEquals(
3, self.persistence.get_unresolved_state_tree.call_count
)

self.assertEqual(1, self.persistence.update_current_state.call_count)
self.assertIsNone(event.state_group)

@defer.inlineCallbacks
def test_missing_pdu_depth_2(self):
# We try to update state against a PDU we haven't yet seen,
# triggering a get_pdu request
def test_annotate_with_old_state(self):
event = self.create_event(type="state", state_key="", name="event")

# The pdu we haven't seen
old_pdu_1 = new_fake_pdu(
"A", "test", "mem", "x", None, "u1", depth=0
)

old_pdu_2 = new_fake_pdu(
"B", "test", "mem", "x", "A", "u2", depth=2
)
old_pdu_3 = new_fake_pdu(
"C", "test", "mem", "x", "B", "u3", depth=3
)
new_pdu = new_fake_pdu(
"D", "test", "mem", "x", "A", "u4", depth=1
)

self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 10,
"u2": 10,
"u3": 10,
"u4": 20,
})

# The return_value of `get_unresolved_state_tree`, which changes after
# the call to get_pdu
tree_to_return = [
(
ReturnType([new_pdu], [old_pdu_3]),
1,
),
(
ReturnType(
[new_pdu], [old_pdu_3, old_pdu_2]
),
0,
),
(
ReturnType(
[new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1]
),
None
),
old_state = [
self.create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""),
]

to_return = [0]
yield self.state.annotate_state_groups(event, old_state=old_state)

def return_tree(p):
return tree_to_return[to_return[0]]

def set_return_tree(destination, pdu_origin, pdu_id, outlier=False):
to_return[0] += 1
return defer.succeed(None)

self.persistence.get_unresolved_state_tree.side_effect = return_tree

self.replication.get_pdu.side_effect = set_return_tree

self.persistence.get_pdu.return_value = None

is_new = yield self.state.handle_new_state(new_pdu)

self.assertTrue(is_new)

self.assertEqual(2, self.replication.get_pdu.call_count)

self.replication.get_pdu.assert_has_calls(
[
mock.call(
destination=old_pdu_3.origin,
pdu_origin=old_pdu_2.origin,
pdu_id=old_pdu_2.pdu_id,
outlier=True
),
mock.call(
destination=new_pdu.origin,
pdu_origin=old_pdu_1.origin,
pdu_id=old_pdu_1.pdu_id,
outlier=True
),
]
)

self.persistence.get_unresolved_state_tree.assert_called_with(
new_pdu
)

self.assertEquals(
3, self.persistence.get_unresolved_state_tree.call_count
)

self.assertEqual(1, self.persistence.update_current_state.call_count)

@defer.inlineCallbacks
def test_no_common_ancestor(self):
# We do a direct overwriting of the old state, i.e., the new state
# points to the old state.

old_pdu = new_fake_pdu("A", "test", "mem", "x", None, "u1")
new_pdu = new_fake_pdu("B", "test", "mem", "x", None, "u2")

self.persistence.get_power_level.side_effect = _gen_get_power_level({
"u1": 5,
"u2": 10,
})

self.persistence.get_unresolved_state_tree.return_value = (
(ReturnType([new_pdu], [old_pdu]), None)
)

is_new = yield self.state.handle_new_state(new_pdu)

self.assertTrue(is_new)

self.persistence.get_unresolved_state_tree.assert_called_once_with(
new_pdu
)

self.assertEqual(1, self.persistence.update_current_state.call_count)

self.assertFalse(self.replication.get_pdu.called)

@defer.inlineCallbacks
def test_new_event(self):
event = Mock()
event.event_id = "12123123@test"

state_pdu = new_fake_pdu("C", "test", "mem", "x", "A", 20)

snapshot = Mock()
snapshot.prev_state_pdu = state_pdu
event_id = "pdu_id@origin.com"

def fill_out_prev_events(event):
event.prev_events = [event_id]
event.depth = 6
snapshot.fill_out_prev_events = fill_out_prev_events

yield self.state.handle_new_event(event, snapshot)

self.assertLess(5, event.depth)

self.assertEquals(1, len(event.prev_events))

prev_id = event.prev_events[0]

self.assertEqual(event_id, prev_id)
for k, v in event.old_state_events.items():
type, state_key = k
self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key)

self.assertEqual(
encode_event_id(state_pdu.pdu_id, state_pdu.origin),
event.prev_state
set(old_state + [event]),
set(event.old_state_events.values())
)

self.assertDictEqual(event.old_state_events, event.state_events)

def new_fake_pdu(pdu_id, context, pdu_type, state_key, prev_state_id,
user_id, depth=0):
new_pdu = Pdu(
pdu_id=pdu_id,
pdu_type=pdu_type,
state_key=state_key,
user_id=user_id,
prev_state_id=prev_state_id,
origin="example.com",
context="context",
origin_server_ts=1405353060021,
depth=depth,
content_json="{}",
unrecognized_keys="{}",
outlier=True,
is_state=True,
prev_state_origin="example.com",
have_processed=True,
content={},
)
self.assertIsNone(event.state_group)

return new_pdu
@defer.inlineCallbacks
def test_trivial_annotate_message(self):
event = self.create_event(type="test_message", name="event")
event.prev_events = []

old_state = [
self.create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""),
]

group_name = "group_name_1"

self.store.get_state_groups.return_value = {
group_name: old_state,
}

yield self.state.annotate_state_groups(event)

for k, v in event.old_state_events.items():
type, state_key = k
self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key)

self.assertEqual(
set([e.event_id for e in old_state]),
set([e.event_id for e in event.old_state_events.values()])
)

self.assertDictEqual(
{
k: v.event_id
for k, v in event.old_state_events.items()
},
{
k: v.event_id
for k, v in event.state_events.items()
}
)

self.assertEqual(group_name, event.state_group)

@defer.inlineCallbacks
def test_trivial_annotate_state(self):
event = self.create_event(type="state", state_key="", name="event")
event.prev_events = []

old_state = [
self.create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""),
]

group_name = "group_name_1"

self.store.get_state_groups.return_value = {
group_name: old_state,
}

yield self.state.annotate_state_groups(event)

for k, v in event.old_state_events.items():
type, state_key = k
self.assertEqual(type, v.type)
self.assertEqual(state_key, v.state_key)

self.assertEqual(
set([e.event_id for e in old_state]),
set([e.event_id for e in event.old_state_events.values()])
)

self.assertEqual(
set([e.event_id for e in old_state] + [event.event_id]),
set([e.event_id for e in event.state_events.values()])
)

new_state = {
k: v.event_id
for k, v in event.state_events.items()
}
old_state = {
k: v.event_id
for k, v in event.old_state_events.items()
}
old_state[(event.type, event.state_key)] = event.event_id
self.assertDictEqual(
old_state,
new_state
)

self.assertIsNone(event.state_group)

@defer.inlineCallbacks
def test_resolve_message_conflict(self):
event = self.create_event(type="test_message", name="event")
event.prev_events = []

old_state_1 = [
self.create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""),
]

old_state_2 = [
self.create_event(type="test1", state_key="1"),
self.create_event(type="test3", state_key="2"),
self.create_event(type="test4", state_key=""),
]

group_name_1 = "group_name_1"
group_name_2 = "group_name_2"

self.store.get_state_groups.return_value = {
group_name_1: old_state_1,
group_name_2: old_state_2,
}

yield self.state.annotate_state_groups(event)

self.assertEqual(len(event.old_state_events), 5)

self.assertEqual(
set([e.event_id for e in event.state_events.values()]),
set([e.event_id for e in event.old_state_events.values()])
)

self.assertIsNone(event.state_group)

@defer.inlineCallbacks
def test_resolve_state_conflict(self):
event = self.create_event(type="test4", state_key="", name="event")
event.prev_events = []

old_state_1 = [
self.create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""),
]

old_state_2 = [
self.create_event(type="test1", state_key="1"),
self.create_event(type="test3", state_key="2"),
self.create_event(type="test4", state_key=""),
]

group_name_1 = "group_name_1"
group_name_2 = "group_name_2"

self.store.get_state_groups.return_value = {
group_name_1: old_state_1,
group_name_2: old_state_2,
}

yield self.state.annotate_state_groups(event)

self.assertEqual(len(event.old_state_events), 5)

expected_new = event.old_state_events
expected_new[(event.type, event.state_key)] = event

self.assertEqual(
set([e.event_id for e in expected_new.values()]),
set([e.event_id for e in event.state_events.values()]),
)

self.assertIsNone(event.state_group)

def create_event(self, name=None, type=None, state_key=None):
self.event_id += 1
event_id = str(self.event_id)

if not name:
if state_key is not None:
name = "<%s-%s>" % (type, state_key)
else:
name = "<%s>" % (type, )

event = Mock(name=name, spec=[])
event.type = type

if state_key is not None:
event.state_key = state_key
event.event_id = event_id

event.user_id = "@user_id:example.com"
event.room_id = "!room_id:example.com"

return event

@ -118,13 +118,14 @@ class MockHttpResource(HttpServer):
class MockKey(object):
alg = "mock_alg"
version = "mock_version"
signature = b"\x9a\x87$"

@property
def verify_key(self):
return self

def sign(self, message):
return b"\x9a\x87$"
return self

def verify(self, message, sig):
assert sig == b"\x9a\x87$"