mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-12-26 02:59:22 -05:00
Merge branch 'event_signing' of github.com:matrix-org/synapse into federation_authorization
Conflicts: synapse/storage/__init__.py
This commit is contained in:
commit
ad9226eeec
@ -1,13 +1,13 @@
|
||||
Signing JSON
|
||||
============
|
||||
|
||||
JSON is signed by encoding the JSON object without ``signatures`` or ``meta``
|
||||
JSON is signed by encoding the JSON object without ``signatures`` or ``unsigned``
|
||||
keys using a canonical encoding. The JSON bytes are then signed using the
|
||||
signature algorithm and the signature encoded using base64 with the padding
|
||||
stripped. The resulting base64 signature is added to an object under the
|
||||
*signing key identifier* which is added to the ``signatures`` object under the
|
||||
name of the server signing it which is added back to the original JSON object
|
||||
along with the ``meta`` object.
|
||||
along with the ``unsigned`` object.
|
||||
|
||||
The *signing key identifier* is the concatenation of the *signing algorithm*
|
||||
and a *key version*. The *signing algorithm* identifies the algorithm used to
|
||||
@ -15,8 +15,8 @@ sign the JSON. The currently support value for *signing algorithm* is
|
||||
``ed25519`` as implemented by NACL (http://nacl.cr.yp.to/). The *key version*
|
||||
is used to distinguish between different signing keys used by the same entity.
|
||||
|
||||
The ``meta`` object and the ``signatures`` object are not covered by the
|
||||
signature. Therefore intermediate servers can add metadata such as time stamps
|
||||
The ``unsigned`` object and the ``signatures`` object are not covered by the
|
||||
signature. Therefore intermediate servers can add unsigneddata such as time stamps
|
||||
and additional signatures.
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ and additional signatures.
|
||||
"signing_keys": {
|
||||
"ed25519:1": "XSl0kuyvrXNj6A+7/tkrB9sxSbRi08Of5uRhxOqZtEQ"
|
||||
},
|
||||
"meta": {
|
||||
"unsigned": {
|
||||
"retrieved_ts_ms": 922834800000
|
||||
},
|
||||
"signatures": {
|
||||
@ -41,7 +41,7 @@ and additional signatures.
|
||||
|
||||
def sign_json(json_object, signing_key, signing_name):
|
||||
signatures = json_object.pop("signatures", {})
|
||||
meta = json_object.pop("meta", None)
|
||||
unsigned = json_object.pop("unsigned", None)
|
||||
|
||||
signed = signing_key.sign(encode_canonical_json(json_object))
|
||||
signature_base64 = encode_base64(signed.signature)
|
||||
@ -50,8 +50,8 @@ and additional signatures.
|
||||
signatures.setdefault(sigature_name, {})[key_id] = signature_base64
|
||||
|
||||
json_object["signatures"] = signatures
|
||||
if meta is not None:
|
||||
json_object["meta"] = meta
|
||||
if unsigned is not None:
|
||||
json_object["unsigned"] = unsigned
|
||||
|
||||
return json_object
|
||||
|
||||
|
69
scripts/hash_history.py
Normal file
69
scripts/hash_history.py
Normal file
@ -0,0 +1,69 @@
|
||||
from synapse.storage.pdu import PduStore
|
||||
from synapse.storage.signatures import SignatureStore
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.federation.units import Pdu
|
||||
from synapse.crypto.event_signing import (
|
||||
add_event_pdu_content_hash, compute_pdu_event_reference_hash
|
||||
)
|
||||
from synapse.api.events.utils import prune_pdu
|
||||
from syutil.base64util import encode_base64, decode_base64
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
import sqlite3
|
||||
import sys
|
||||
|
||||
class Store(object):
|
||||
_get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"]
|
||||
_get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"]
|
||||
_get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"]
|
||||
_get_pdu_origin_signatures_txn = SignatureStore.__dict__["_get_pdu_origin_signatures_txn"]
|
||||
_store_pdu_content_hash_txn = SignatureStore.__dict__["_store_pdu_content_hash_txn"]
|
||||
_store_pdu_reference_hash_txn = SignatureStore.__dict__["_store_pdu_reference_hash_txn"]
|
||||
_store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"]
|
||||
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
|
||||
|
||||
|
||||
store = Store()
|
||||
|
||||
|
||||
def select_pdus(cursor):
|
||||
cursor.execute(
|
||||
"SELECT pdu_id, origin FROM pdus ORDER BY depth ASC"
|
||||
)
|
||||
|
||||
ids = cursor.fetchall()
|
||||
|
||||
pdu_tuples = store._get_pdu_tuples(cursor, ids)
|
||||
|
||||
pdus = [Pdu.from_pdu_tuple(p) for p in pdu_tuples]
|
||||
|
||||
reference_hashes = {}
|
||||
|
||||
for pdu in pdus:
|
||||
try:
|
||||
if pdu.prev_pdus:
|
||||
print "PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus
|
||||
for pdu_id, origin, hashes in pdu.prev_pdus:
|
||||
ref_alg, ref_hsh = reference_hashes[(pdu_id, origin)]
|
||||
hashes[ref_alg] = encode_base64(ref_hsh)
|
||||
store._store_prev_pdu_hash_txn(cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh)
|
||||
print "SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus
|
||||
pdu = add_event_pdu_content_hash(pdu)
|
||||
ref_alg, ref_hsh = compute_pdu_event_reference_hash(pdu)
|
||||
reference_hashes[(pdu.pdu_id, pdu.origin)] = (ref_alg, ref_hsh)
|
||||
store._store_pdu_reference_hash_txn(cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh)
|
||||
|
||||
for alg, hsh_base64 in pdu.hashes.items():
|
||||
print alg, hsh_base64
|
||||
store._store_pdu_content_hash_txn(cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64))
|
||||
|
||||
except:
|
||||
print "FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus
|
||||
|
||||
def main():
|
||||
conn = sqlite3.connect(sys.argv[1])
|
||||
cursor = conn.cursor()
|
||||
select_pdus(cursor)
|
||||
conn.commit()
|
||||
|
||||
if __name__=='__main__':
|
||||
main()
|
@ -65,13 +65,13 @@ class SynapseEvent(JsonEncodedObject):
|
||||
|
||||
internal_keys = [
|
||||
"is_state",
|
||||
"prev_events",
|
||||
"depth",
|
||||
"destinations",
|
||||
"origin",
|
||||
"outlier",
|
||||
"power_level",
|
||||
"redacted",
|
||||
"prev_pdus",
|
||||
]
|
||||
|
||||
required_keys = [
|
||||
|
@ -27,7 +27,14 @@ def prune_event(event):
|
||||
the user has specified, but we do want to keep necessary information like
|
||||
type, state_key etc.
|
||||
"""
|
||||
return _prune_event_or_pdu(event.type, event)
|
||||
|
||||
def prune_pdu(pdu):
|
||||
"""Removes keys that contain unrestricted and non-essential data from a PDU
|
||||
"""
|
||||
return _prune_event_or_pdu(pdu.pdu_type, pdu)
|
||||
|
||||
def _prune_event_or_pdu(event_type, event):
|
||||
# Remove all extraneous fields.
|
||||
event.unrecognized_keys = {}
|
||||
|
||||
@ -38,25 +45,25 @@ def prune_event(event):
|
||||
if field in event.content:
|
||||
new_content[field] = event.content[field]
|
||||
|
||||
if event.type == RoomMemberEvent.TYPE:
|
||||
if event_type == RoomMemberEvent.TYPE:
|
||||
add_fields("membership")
|
||||
elif event.type == RoomCreateEvent.TYPE:
|
||||
elif event_type == RoomCreateEvent.TYPE:
|
||||
add_fields("creator")
|
||||
elif event.type == RoomJoinRulesEvent.TYPE:
|
||||
elif event_type == RoomJoinRulesEvent.TYPE:
|
||||
add_fields("join_rule")
|
||||
elif event.type == RoomPowerLevelsEvent.TYPE:
|
||||
elif event_type == RoomPowerLevelsEvent.TYPE:
|
||||
# TODO: Actually check these are valid user_ids etc.
|
||||
add_fields("default")
|
||||
for k, v in event.content.items():
|
||||
if k.startswith("@") and isinstance(v, (int, long)):
|
||||
new_content[k] = v
|
||||
elif event.type == RoomAddStateLevelEvent.TYPE:
|
||||
elif event_type == RoomAddStateLevelEvent.TYPE:
|
||||
add_fields("level")
|
||||
elif event.type == RoomSendEventLevelEvent.TYPE:
|
||||
elif event_type == RoomSendEventLevelEvent.TYPE:
|
||||
add_fields("level")
|
||||
elif event.type == RoomOpsPowerLevelsEvent.TYPE:
|
||||
elif event_type == RoomOpsPowerLevelsEvent.TYPE:
|
||||
add_fields("kick_level", "ban_level", "redact_level")
|
||||
elif event.type == RoomAliasesEvent.TYPE:
|
||||
elif event_type == RoomAliasesEvent.TYPE:
|
||||
add_fields("aliases")
|
||||
|
||||
event.content = new_content
|
||||
|
@ -74,7 +74,7 @@ class ServerConfig(Config):
|
||||
return syutil.crypto.signing_key.read_signing_keys(
|
||||
signing_keys.splitlines(True)
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise ConfigError(
|
||||
"Error reading signing_key."
|
||||
" Try running again with --generate-config"
|
||||
|
85
synapse/crypto/event_signing.py
Normal file
85
synapse/crypto/event_signing.py
Normal file
@ -0,0 +1,85 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from synapse.federation.units import Pdu
|
||||
from synapse.api.events.utils import prune_pdu
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
from syutil.base64util import encode_base64, decode_base64
|
||||
from syutil.crypto.jsonsign import sign_json, verify_signed_json
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def add_event_pdu_content_hash(pdu, hash_algorithm=hashlib.sha256):
|
||||
hashed = _compute_content_hash(pdu, hash_algorithm)
|
||||
pdu.hashes[hashed.name] = encode_base64(hashed.digest())
|
||||
return pdu
|
||||
|
||||
|
||||
def check_event_pdu_content_hash(pdu, hash_algorithm=hashlib.sha256):
|
||||
"""Check whether the hash for this PDU matches the contents"""
|
||||
computed_hash = _compute_content_hash(pdu, hash_algorithm)
|
||||
if computed_hash.name not in pdu.hashes:
|
||||
raise Exception("Algorithm %s not in hashes %s" % (
|
||||
computed_hash.name, list(pdu.hashes)
|
||||
))
|
||||
message_hash_base64 = pdu.hashes[computed_hash.name]
|
||||
try:
|
||||
message_hash_bytes = decode_base64(message_hash_base64)
|
||||
except:
|
||||
raise Exception("Invalid base64: %s" % (message_hash_base64,))
|
||||
return message_hash_bytes == computed_hash.digest()
|
||||
|
||||
|
||||
def _compute_content_hash(pdu, hash_algorithm):
|
||||
pdu_json = pdu.get_dict()
|
||||
#TODO: Make "age_ts" key internal
|
||||
pdu_json.pop("age_ts", None)
|
||||
pdu_json.pop("unsigned", None)
|
||||
pdu_json.pop("signatures", None)
|
||||
pdu_json.pop("hashes", None)
|
||||
pdu_json_bytes = encode_canonical_json(pdu_json)
|
||||
return hash_algorithm(pdu_json_bytes)
|
||||
|
||||
|
||||
def compute_pdu_event_reference_hash(pdu, hash_algorithm=hashlib.sha256):
|
||||
tmp_pdu = Pdu(**pdu.get_dict())
|
||||
tmp_pdu = prune_pdu(tmp_pdu)
|
||||
pdu_json = tmp_pdu.get_dict()
|
||||
pdu_json.pop("signatures", None)
|
||||
pdu_json_bytes = encode_canonical_json(pdu_json)
|
||||
hashed = hash_algorithm(pdu_json_bytes)
|
||||
return (hashed.name, hashed.digest())
|
||||
|
||||
|
||||
def sign_event_pdu(pdu, signature_name, signing_key):
|
||||
tmp_pdu = Pdu(**pdu.get_dict())
|
||||
tmp_pdu = prune_pdu(tmp_pdu)
|
||||
pdu_json = tmp_pdu.get_dict()
|
||||
pdu_json = sign_json(pdu_json, signature_name, signing_key)
|
||||
pdu.signatures = pdu_json["signatures"]
|
||||
return pdu
|
||||
|
||||
|
||||
def verify_signed_event_pdu(pdu, signature_name, verify_key):
|
||||
tmp_pdu = Pdu(**pdu.get_dict())
|
||||
tmp_pdu = prune_pdu(tmp_pdu)
|
||||
pdu_json = tmp_pdu.get_dict()
|
||||
verify_signed_json(pdu_json, signature_name, verify_key)
|
@ -17,7 +17,6 @@
|
||||
from twisted.web.http import HTTPClient
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.endpoints import connectProtocol
|
||||
from synapse.http.endpoint import matrix_endpoint
|
||||
import json
|
||||
import logging
|
||||
|
@ -14,6 +14,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .units import Pdu
|
||||
from synapse.crypto.event_signing import (
|
||||
add_event_pdu_content_hash, sign_event_pdu
|
||||
)
|
||||
|
||||
import copy
|
||||
|
||||
@ -33,6 +36,7 @@ def encode_event_id(pdu_id, origin):
|
||||
class PduCodec(object):
|
||||
|
||||
def __init__(self, hs):
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
self.server_name = hs.hostname
|
||||
self.event_factory = hs.get_event_factory()
|
||||
self.clock = hs.get_clock()
|
||||
@ -43,9 +47,7 @@ class PduCodec(object):
|
||||
kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
|
||||
kwargs["room_id"] = pdu.context
|
||||
kwargs["etype"] = pdu.pdu_type
|
||||
kwargs["prev_events"] = [
|
||||
encode_event_id(p[0], p[1]) for p in pdu.prev_pdus
|
||||
]
|
||||
kwargs["prev_pdus"] = pdu.prev_pdus
|
||||
|
||||
if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
|
||||
kwargs["prev_state"] = encode_event_id(
|
||||
@ -76,11 +78,8 @@ class PduCodec(object):
|
||||
d["context"] = event.room_id
|
||||
d["pdu_type"] = event.type
|
||||
|
||||
if hasattr(event, "prev_events"):
|
||||
d["prev_pdus"] = [
|
||||
decode_event_id(e, self.server_name)
|
||||
for e in event.prev_events
|
||||
]
|
||||
if hasattr(event, "prev_pdus"):
|
||||
d["prev_pdus"] = event.prev_pdus
|
||||
|
||||
if hasattr(event, "prev_state"):
|
||||
d["prev_state_id"], d["prev_state_origin"] = (
|
||||
@ -93,10 +92,12 @@ class PduCodec(object):
|
||||
kwargs = copy.deepcopy(event.unrecognized_keys)
|
||||
kwargs.update({
|
||||
k: v for k, v in d.items()
|
||||
if k not in ["event_id", "room_id", "type", "prev_events"]
|
||||
if k not in ["event_id", "room_id", "type"]
|
||||
})
|
||||
|
||||
if "origin_server_ts" not in kwargs:
|
||||
kwargs["origin_server_ts"] = int(self.clock.time_msec())
|
||||
|
||||
return Pdu(**kwargs)
|
||||
pdu = Pdu(**kwargs)
|
||||
pdu = add_event_pdu_content_hash(pdu)
|
||||
return sign_event_pdu(pdu, self.server_name, self.signing_key)
|
||||
|
@ -297,6 +297,10 @@ class ReplicationLayer(object):
|
||||
transaction = Transaction(**transaction_data)
|
||||
|
||||
for p in transaction.pdus:
|
||||
if "unsigned" in p:
|
||||
unsigned = p["unsigned"]
|
||||
if "age" in unsigned:
|
||||
p["age"] = unsigned["age"]
|
||||
if "age" in p:
|
||||
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
|
||||
del p["age"]
|
||||
@ -467,14 +471,16 @@ class ReplicationLayer(object):
|
||||
transmission.
|
||||
"""
|
||||
pdus = [p.get_dict() for p in pdu_list]
|
||||
time_now = self._clock.time_msec()
|
||||
for p in pdus:
|
||||
if "age_ts" in pdus:
|
||||
p["age"] = int(self.clock.time_msec()) - p["age_ts"]
|
||||
|
||||
if "age_ts" in p:
|
||||
age = time_now - p["age_ts"]
|
||||
p.setdefault("unsigned", {})["age"] = int(age)
|
||||
del p["age_ts"]
|
||||
return Transaction(
|
||||
origin=self.server_name,
|
||||
pdus=pdus,
|
||||
origin_server_ts=int(self._clock.time_msec()),
|
||||
origin_server_ts=int(time_now),
|
||||
destination=None,
|
||||
)
|
||||
|
||||
@ -498,7 +504,7 @@ class ReplicationLayer(object):
|
||||
min_depth = yield self.store.get_min_depth_for_context(pdu.context)
|
||||
|
||||
if min_depth and pdu.depth > min_depth:
|
||||
for pdu_id, origin in pdu.prev_pdus:
|
||||
for pdu_id, origin, hashes in pdu.prev_pdus:
|
||||
exists = yield self._get_persisted_pdu(pdu_id, origin)
|
||||
|
||||
if not exists:
|
||||
@ -654,7 +660,7 @@ class _TransactionQueue(object):
|
||||
logger.debug("TX [%s] Persisting transaction...", destination)
|
||||
|
||||
transaction = Transaction.create_new(
|
||||
origin_server_ts=self._clock.time_msec(),
|
||||
origin_server_ts=int(self._clock.time_msec()),
|
||||
transaction_id=str(self._next_txn_id),
|
||||
origin=self.server_name,
|
||||
destination=destination,
|
||||
@ -679,7 +685,9 @@ class _TransactionQueue(object):
|
||||
if "pdus" in data:
|
||||
for p in data["pdus"]:
|
||||
if "age_ts" in p:
|
||||
p["age"] = now - int(p["age_ts"])
|
||||
unsigned = p.setdefault("unsigned", {})
|
||||
unsigned["age"] = now - int(p["age_ts"])
|
||||
del p["age_ts"]
|
||||
return data
|
||||
|
||||
code, response = yield self.transport_layer.send_transaction(
|
||||
|
@ -18,6 +18,7 @@ server protocol.
|
||||
"""
|
||||
|
||||
from synapse.util.jsonobject import JsonEncodedObject
|
||||
from syutil.base64util import encode_base64
|
||||
|
||||
import logging
|
||||
import json
|
||||
@ -63,9 +64,10 @@ class Pdu(JsonEncodedObject):
|
||||
"depth",
|
||||
"content",
|
||||
"outlier",
|
||||
"hashes",
|
||||
"signatures",
|
||||
"is_state", # Below this are keys valid only for State Pdus.
|
||||
"state_key",
|
||||
"power_level",
|
||||
"prev_state_id",
|
||||
"prev_state_origin",
|
||||
"required_power_level",
|
||||
@ -91,7 +93,7 @@ class Pdu(JsonEncodedObject):
|
||||
# just leaving it as a dict. (OR DO WE?!)
|
||||
|
||||
def __init__(self, destinations=[], is_state=False, prev_pdus=[],
|
||||
outlier=False, **kwargs):
|
||||
outlier=False, hashes={}, signatures={}, **kwargs):
|
||||
if is_state:
|
||||
for required_key in ["state_key"]:
|
||||
if required_key not in kwargs:
|
||||
@ -99,9 +101,11 @@ class Pdu(JsonEncodedObject):
|
||||
|
||||
super(Pdu, self).__init__(
|
||||
destinations=destinations,
|
||||
is_state=is_state,
|
||||
is_state=bool(is_state),
|
||||
prev_pdus=prev_pdus,
|
||||
outlier=outlier,
|
||||
hashes=hashes,
|
||||
signatures=signatures,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@ -120,6 +124,10 @@ class Pdu(JsonEncodedObject):
|
||||
d = copy.copy(pdu_tuple.pdu_entry._asdict())
|
||||
d["origin_server_ts"] = d.pop("ts")
|
||||
|
||||
for k in d.keys():
|
||||
if d[k] is None:
|
||||
del d[k]
|
||||
|
||||
d["content"] = json.loads(d["content_json"])
|
||||
del d["content_json"]
|
||||
|
||||
@ -127,8 +135,28 @@ class Pdu(JsonEncodedObject):
|
||||
if "unrecognized_keys" in d and d["unrecognized_keys"]:
|
||||
args.update(json.loads(d["unrecognized_keys"]))
|
||||
|
||||
hashes = {
|
||||
alg: encode_base64(hsh)
|
||||
for alg, hsh in pdu_tuple.hashes.items()
|
||||
}
|
||||
|
||||
signatures = {
|
||||
kid: encode_base64(sig)
|
||||
for kid, sig in pdu_tuple.signatures.items()
|
||||
}
|
||||
|
||||
prev_pdus = []
|
||||
for prev_pdu in pdu_tuple.prev_pdu_list:
|
||||
prev_hashes = pdu_tuple.edge_hashes.get(prev_pdu, {})
|
||||
prev_hashes = {
|
||||
alg: encode_base64(hsh) for alg, hsh in prev_hashes.items()
|
||||
}
|
||||
prev_pdus.append((prev_pdu[0], prev_pdu[1], prev_hashes))
|
||||
|
||||
return Pdu(
|
||||
prev_pdus=pdu_tuple.prev_pdu_list,
|
||||
prev_pdus=prev_pdus,
|
||||
hashes=hashes,
|
||||
signatures=signatures,
|
||||
**args
|
||||
)
|
||||
else:
|
||||
|
@ -344,7 +344,7 @@ class RoomInitialSyncRestServlet(RestServlet):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
user = yield self.auth.get_user_by_req(request)
|
||||
yield self.auth.get_user_by_req(request)
|
||||
# TODO: Get all the initial sync data for this room and return in the
|
||||
# same format as initial sync, that is:
|
||||
# {
|
||||
|
@ -75,10 +75,6 @@ class StateHandler(object):
|
||||
snapshot.fill_out_prev_events(event)
|
||||
yield self.annotate_state_groups(event)
|
||||
|
||||
event.prev_events = [
|
||||
e for e in event.prev_events if e != event.event_id
|
||||
]
|
||||
|
||||
current_state = snapshot.prev_state_pdu
|
||||
|
||||
if current_state:
|
||||
|
@ -40,7 +40,14 @@ from .stream import StreamStore
|
||||
from .pdu import StatePduStore, PduStore, PdusTable
|
||||
from .transactions import TransactionStore
|
||||
from .keys import KeyStore
|
||||
|
||||
from .state import StateStore
|
||||
from .signatures import SignatureStore
|
||||
|
||||
from syutil.base64util import decode_base64
|
||||
|
||||
from synapse.crypto.event_signing import compute_pdu_event_reference_hash
|
||||
|
||||
|
||||
import json
|
||||
import logging
|
||||
@ -61,6 +68,7 @@ SCHEMAS = [
|
||||
"keys",
|
||||
"redactions",
|
||||
"state",
|
||||
"signatures",
|
||||
]
|
||||
|
||||
|
||||
@ -78,7 +86,7 @@ class _RollbackButIsFineException(Exception):
|
||||
class DataStore(RoomMemberStore, RoomStore,
|
||||
RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
|
||||
PresenceStore, PduStore, StatePduStore, TransactionStore,
|
||||
DirectoryStore, KeyStore, StateStore):
|
||||
DirectoryStore, KeyStore, StateStore, SignatureStore):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(DataStore, self).__init__(hs)
|
||||
@ -146,6 +154,8 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
def _persist_event_pdu_txn(self, txn, pdu):
|
||||
cols = dict(pdu.__dict__)
|
||||
unrec_keys = dict(pdu.unrecognized_keys)
|
||||
del cols["hashes"]
|
||||
del cols["signatures"]
|
||||
del cols["content"]
|
||||
del cols["prev_pdus"]
|
||||
cols["content_json"] = json.dumps(pdu.content)
|
||||
@ -161,6 +171,33 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
|
||||
logger.debug("Persisting: %s", repr(cols))
|
||||
|
||||
for hash_alg, hash_base64 in pdu.hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
self._store_pdu_content_hash_txn(
|
||||
txn, pdu.pdu_id, pdu.origin, hash_alg, hash_bytes,
|
||||
)
|
||||
|
||||
signatures = pdu.signatures.get(pdu.origin, {})
|
||||
|
||||
for key_id, signature_base64 in signatures.items():
|
||||
signature_bytes = decode_base64(signature_base64)
|
||||
self._store_pdu_origin_signature_txn(
|
||||
txn, pdu.pdu_id, pdu.origin, key_id, signature_bytes,
|
||||
)
|
||||
|
||||
for prev_pdu_id, prev_origin, prev_hashes in pdu.prev_pdus:
|
||||
for alg, hash_base64 in prev_hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
self._store_prev_pdu_hash_txn(
|
||||
txn, pdu.pdu_id, pdu.origin, prev_pdu_id, prev_origin, alg,
|
||||
hash_bytes
|
||||
)
|
||||
|
||||
(ref_alg, ref_hash_bytes) = compute_pdu_event_reference_hash(pdu)
|
||||
self._store_pdu_reference_hash_txn(
|
||||
txn, pdu.pdu_id, pdu.origin, ref_alg, ref_hash_bytes
|
||||
)
|
||||
|
||||
if pdu.is_state:
|
||||
self._persist_state_txn(txn, pdu.prev_pdus, cols)
|
||||
else:
|
||||
@ -338,6 +375,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
prev_pdus = self._get_latest_pdus_in_context(
|
||||
txn, room_id
|
||||
)
|
||||
|
||||
if state_type is not None and state_key is not None:
|
||||
prev_state_pdu = self._get_current_state_pdu(
|
||||
txn, room_id, state_type, state_key
|
||||
@ -387,17 +425,16 @@ class Snapshot(object):
|
||||
self.prev_state_pdu = prev_state_pdu
|
||||
|
||||
def fill_out_prev_events(self, event):
|
||||
if hasattr(event, "prev_events"):
|
||||
if hasattr(event, "prev_pdus"):
|
||||
return
|
||||
|
||||
es = [
|
||||
"%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus
|
||||
event.prev_pdus = [
|
||||
(p_id, origin, hashes)
|
||||
for p_id, origin, hashes, _ in self.prev_pdus
|
||||
]
|
||||
|
||||
event.prev_events = [e for e in es if e != event.event_id]
|
||||
|
||||
if self.prev_pdus:
|
||||
event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1
|
||||
event.depth = max([int(v) for _, _, _, v in self.prev_pdus]) + 1
|
||||
else:
|
||||
event.depth = 0
|
||||
|
||||
|
@ -104,7 +104,6 @@ class KeyStore(SQLBaseStore):
|
||||
ts_now_ms (int): The time now in milliseconds
|
||||
verification_key (VerifyKey): The NACL verify key.
|
||||
"""
|
||||
verify_key_bytes = verify_key.encode()
|
||||
return self._simple_insert(
|
||||
table="server_signature_keys",
|
||||
values={
|
||||
|
@ -20,10 +20,13 @@ from ._base import SQLBaseStore, Table, JoinHelper
|
||||
from synapse.federation.units import Pdu
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
from syutil.base64util import encode_base64
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -64,6 +67,13 @@ class PduStore(SQLBaseStore):
|
||||
for r in PduEdgesTable.decode_results(txn.fetchall())
|
||||
]
|
||||
|
||||
edge_hashes = self._get_prev_pdu_hashes_txn(txn, pdu_id, origin)
|
||||
|
||||
hashes = self._get_pdu_content_hashes_txn(txn, pdu_id, origin)
|
||||
signatures = self._get_pdu_origin_signatures_txn(
|
||||
txn, pdu_id, origin
|
||||
)
|
||||
|
||||
query = (
|
||||
"SELECT %(fields)s FROM %(pdus)s as p "
|
||||
"LEFT JOIN %(state)s as s "
|
||||
@ -80,7 +90,9 @@ class PduStore(SQLBaseStore):
|
||||
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
results.append(PduTuple(PduEntry(*row), edges))
|
||||
results.append(PduTuple(
|
||||
PduEntry(*row), edges, hashes, signatures, edge_hashes
|
||||
))
|
||||
|
||||
return results
|
||||
|
||||
@ -309,9 +321,14 @@ class PduStore(SQLBaseStore):
|
||||
(context, )
|
||||
)
|
||||
|
||||
results = txn.fetchall()
|
||||
results = []
|
||||
for pdu_id, origin, depth in txn.fetchall():
|
||||
hashes = self._get_pdu_reference_hashes_txn(txn, pdu_id, origin)
|
||||
sha256_bytes = hashes["sha256"]
|
||||
prev_hashes = {"sha256": encode_base64(sha256_bytes)}
|
||||
results.append((pdu_id, origin, prev_hashes, depth))
|
||||
|
||||
return [(row[0], row[1], row[2]) for row in results]
|
||||
return results
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_oldest_pdus_in_context(self, context):
|
||||
@ -430,7 +447,7 @@ class PduStore(SQLBaseStore):
|
||||
"DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
|
||||
% PduForwardExtremitiesTable.table_name
|
||||
)
|
||||
txn.executemany(query, prev_pdus)
|
||||
txn.executemany(query, list(p[:2] for p in prev_pdus))
|
||||
|
||||
# We only insert as a forward extremety the new pdu if there are no
|
||||
# other pdus that reference it as a prev pdu
|
||||
@ -453,7 +470,7 @@ class PduStore(SQLBaseStore):
|
||||
# deleted in a second if they're incorrect anyway.
|
||||
txn.executemany(
|
||||
PduBackwardExtremitiesTable.insert_statement(),
|
||||
[(i, o, context) for i, o in prev_pdus]
|
||||
[(i, o, context) for i, o, _ in prev_pdus]
|
||||
)
|
||||
|
||||
# Also delete from the backwards extremities table all ones that
|
||||
@ -914,7 +931,7 @@ This does not include a prev_pdus key.
|
||||
|
||||
PduTuple = namedtuple(
|
||||
"PduTuple",
|
||||
("pdu_entry", "prev_pdu_list")
|
||||
("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes")
|
||||
)
|
||||
""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
|
||||
the `prev_pdus` key of a PDU.
|
||||
|
66
synapse/storage/schema/signatures.sql
Normal file
66
synapse/storage/schema/signatures.sql
Normal file
@ -0,0 +1,66 @@
|
||||
/* Copyright 2014 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE IF NOT EXISTS pdu_content_hashes (
|
||||
pdu_id TEXT,
|
||||
origin TEXT,
|
||||
algorithm TEXT,
|
||||
hash BLOB,
|
||||
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, algorithm)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS pdu_content_hashes_id ON pdu_content_hashes (
|
||||
pdu_id, origin
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS pdu_reference_hashes (
|
||||
pdu_id TEXT,
|
||||
origin TEXT,
|
||||
algorithm TEXT,
|
||||
hash BLOB,
|
||||
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, algorithm)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS pdu_reference_hashes_id ON pdu_reference_hashes (
|
||||
pdu_id, origin
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS pdu_origin_signatures (
|
||||
pdu_id TEXT,
|
||||
origin TEXT,
|
||||
key_id TEXT,
|
||||
signature BLOB,
|
||||
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, key_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS pdu_origin_signatures_id ON pdu_origin_signatures (
|
||||
pdu_id, origin
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS pdu_edge_hashes(
|
||||
pdu_id TEXT,
|
||||
origin TEXT,
|
||||
prev_pdu_id TEXT,
|
||||
prev_origin TEXT,
|
||||
algorithm TEXT,
|
||||
hash BLOB,
|
||||
CONSTRAINT uniqueness UNIQUE (
|
||||
pdu_id, origin, prev_pdu_id, prev_origin, algorithm
|
||||
)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS pdu_edge_hashes_id ON pdu_edge_hashes(
|
||||
pdu_id, origin
|
||||
);
|
155
synapse/storage/signatures.py
Normal file
155
synapse/storage/signatures.py
Normal file
@ -0,0 +1,155 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from _base import SQLBaseStore
|
||||
|
||||
|
||||
class SignatureStore(SQLBaseStore):
|
||||
"""Persistence for PDU signatures and hashes"""
|
||||
|
||||
def _get_pdu_content_hashes_txn(self, txn, pdu_id, origin):
|
||||
"""Get all the hashes for a given PDU.
|
||||
Args:
|
||||
txn (cursor):
|
||||
pdu_id (str): Id for the PDU.
|
||||
origin (str): origin of the PDU.
|
||||
Returns:
|
||||
A dict of algorithm -> hash.
|
||||
"""
|
||||
query = (
|
||||
"SELECT algorithm, hash"
|
||||
" FROM pdu_content_hashes"
|
||||
" WHERE pdu_id = ? and origin = ?"
|
||||
)
|
||||
txn.execute(query, (pdu_id, origin))
|
||||
return dict(txn.fetchall())
|
||||
|
||||
def _store_pdu_content_hash_txn(self, txn, pdu_id, origin, algorithm,
|
||||
hash_bytes):
|
||||
"""Store a hash for a PDU
|
||||
Args:
|
||||
txn (cursor):
|
||||
pdu_id (str): Id for the PDU.
|
||||
origin (str): origin of the PDU.
|
||||
algorithm (str): Hashing algorithm.
|
||||
hash_bytes (bytes): Hash function output bytes.
|
||||
"""
|
||||
self._simple_insert_txn(txn, "pdu_content_hashes", {
|
||||
"pdu_id": pdu_id,
|
||||
"origin": origin,
|
||||
"algorithm": algorithm,
|
||||
"hash": buffer(hash_bytes),
|
||||
})
|
||||
|
||||
def _get_pdu_reference_hashes_txn(self, txn, pdu_id, origin):
|
||||
"""Get all the hashes for a given PDU.
|
||||
Args:
|
||||
txn (cursor):
|
||||
pdu_id (str): Id for the PDU.
|
||||
origin (str): origin of the PDU.
|
||||
Returns:
|
||||
A dict of algorithm -> hash.
|
||||
"""
|
||||
query = (
|
||||
"SELECT algorithm, hash"
|
||||
" FROM pdu_reference_hashes"
|
||||
" WHERE pdu_id = ? and origin = ?"
|
||||
)
|
||||
txn.execute(query, (pdu_id, origin))
|
||||
return dict(txn.fetchall())
|
||||
|
||||
def _store_pdu_reference_hash_txn(self, txn, pdu_id, origin, algorithm,
|
||||
hash_bytes):
|
||||
"""Store a hash for a PDU
|
||||
Args:
|
||||
txn (cursor):
|
||||
pdu_id (str): Id for the PDU.
|
||||
origin (str): origin of the PDU.
|
||||
algorithm (str): Hashing algorithm.
|
||||
hash_bytes (bytes): Hash function output bytes.
|
||||
"""
|
||||
self._simple_insert_txn(txn, "pdu_reference_hashes", {
|
||||
"pdu_id": pdu_id,
|
||||
"origin": origin,
|
||||
"algorithm": algorithm,
|
||||
"hash": buffer(hash_bytes),
|
||||
})
|
||||
|
||||
|
||||
def _get_pdu_origin_signatures_txn(self, txn, pdu_id, origin):
|
||||
"""Get all the signatures for a given PDU.
|
||||
Args:
|
||||
txn (cursor):
|
||||
pdu_id (str): Id for the PDU.
|
||||
origin (str): origin of the PDU.
|
||||
Returns:
|
||||
A dict of key_id -> signature_bytes.
|
||||
"""
|
||||
query = (
|
||||
"SELECT key_id, signature"
|
||||
" FROM pdu_origin_signatures"
|
||||
" WHERE pdu_id = ? and origin = ?"
|
||||
)
|
||||
txn.execute(query, (pdu_id, origin))
|
||||
return dict(txn.fetchall())
|
||||
|
||||
def _store_pdu_origin_signature_txn(self, txn, pdu_id, origin, key_id,
|
||||
signature_bytes):
|
||||
"""Store a signature from the origin server for a PDU.
|
||||
Args:
|
||||
txn (cursor):
|
||||
pdu_id (str): Id for the PDU.
|
||||
origin (str): origin of the PDU.
|
||||
key_id (str): Id for the signing key.
|
||||
signature (bytes): The signature.
|
||||
"""
|
||||
self._simple_insert_txn(txn, "pdu_origin_signatures", {
|
||||
"pdu_id": pdu_id,
|
||||
"origin": origin,
|
||||
"key_id": key_id,
|
||||
"signature": buffer(signature_bytes),
|
||||
})
|
||||
|
||||
def _get_prev_pdu_hashes_txn(self, txn, pdu_id, origin):
|
||||
"""Get all the hashes for previous PDUs of a PDU
|
||||
Args:
|
||||
txn (cursor):
|
||||
pdu_id (str): Id of the PDU.
|
||||
origin (str): Origin of the PDU.
|
||||
Returns:
|
||||
dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes.
|
||||
"""
|
||||
query = (
|
||||
"SELECT prev_pdu_id, prev_origin, algorithm, hash"
|
||||
" FROM pdu_edge_hashes"
|
||||
" WHERE pdu_id = ? and origin = ?"
|
||||
)
|
||||
txn.execute(query, (pdu_id, origin))
|
||||
results = {}
|
||||
for prev_pdu_id, prev_origin, algorithm, hash_bytes in txn.fetchall():
|
||||
hashes = results.setdefault((prev_pdu_id, prev_origin), {})
|
||||
hashes[algorithm] = hash_bytes
|
||||
return results
|
||||
|
||||
def _store_prev_pdu_hash_txn(self, txn, pdu_id, origin, prev_pdu_id,
|
||||
prev_origin, algorithm, hash_bytes):
|
||||
self._simple_insert_txn(txn, "pdu_edge_hashes", {
|
||||
"pdu_id": pdu_id,
|
||||
"origin": origin,
|
||||
"prev_pdu_id": prev_pdu_id,
|
||||
"prev_origin": prev_origin,
|
||||
"algorithm": algorithm,
|
||||
"hash": buffer(hash_bytes),
|
||||
})
|
1
synapse/test_pyflakes.py
Normal file
1
synapse/test_pyflakes.py
Normal file
@ -0,0 +1 @@
|
||||
import an_unused_module
|
@ -41,7 +41,7 @@ def make_pdu(prev_pdus=[], **kwargs):
|
||||
}
|
||||
pdu_fields.update(kwargs)
|
||||
|
||||
return PduTuple(PduEntry(**pdu_fields), prev_pdus)
|
||||
return PduTuple(PduEntry(**pdu_fields), prev_pdus, {}, {}, {})
|
||||
|
||||
|
||||
class FederationTestCase(unittest.TestCase):
|
||||
@ -183,6 +183,8 @@ class FederationTestCase(unittest.TestCase):
|
||||
"is_state": False,
|
||||
"content": {"testing": "content here"},
|
||||
"depth": 1,
|
||||
"hashes": {},
|
||||
"signatures": {},
|
||||
},
|
||||
]
|
||||
},
|
||||
|
@ -23,14 +23,21 @@ from synapse.federation.units import Pdu
|
||||
|
||||
from synapse.server import HomeServer
|
||||
|
||||
from mock import Mock
|
||||
from mock import Mock, NonCallableMock
|
||||
|
||||
from ..utils import MockKey
|
||||
|
||||
|
||||
class PduCodecTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.hs = HomeServer("blargle.net")
|
||||
self.event_factory = self.hs.get_event_factory()
|
||||
self.mock_config = NonCallableMock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
self.hs = HomeServer(
|
||||
"blargle.net",
|
||||
config=self.mock_config,
|
||||
)
|
||||
self.event_factory = self.hs.get_event_factory()
|
||||
self.codec = PduCodec(self.hs)
|
||||
|
||||
def test_decode_event_id(self):
|
||||
@ -81,7 +88,7 @@ class PduCodecTestCase(unittest.TestCase):
|
||||
self.assertEquals(pdu.context, event.room_id)
|
||||
self.assertEquals(pdu.is_state, event.is_state)
|
||||
self.assertEquals(pdu.depth, event.depth)
|
||||
self.assertEquals(["alice@bob.com"], event.prev_events)
|
||||
self.assertEquals(pdu.prev_pdus, event.prev_pdus)
|
||||
self.assertEquals(pdu.content, event.content)
|
||||
|
||||
def test_pdu_from_event(self):
|
||||
@ -137,7 +144,7 @@ class PduCodecTestCase(unittest.TestCase):
|
||||
self.assertEquals(pdu.context, event.room_id)
|
||||
self.assertEquals(pdu.is_state, event.is_state)
|
||||
self.assertEquals(pdu.depth, event.depth)
|
||||
self.assertEquals(["alice@bob.com"], event.prev_events)
|
||||
self.assertEquals(pdu.prev_pdus, event.prev_pdus)
|
||||
self.assertEquals(pdu.content, event.content)
|
||||
self.assertEquals(pdu.state_key, event.state_key)
|
||||
|
||||
|
@ -28,7 +28,7 @@ from synapse.server import HomeServer
|
||||
# python imports
|
||||
import json
|
||||
|
||||
from ..utils import MockHttpResource, MemoryDataStore
|
||||
from ..utils import MockHttpResource, MemoryDataStore, MockKey
|
||||
from .utils import RestTestCase
|
||||
|
||||
from mock import Mock, NonCallableMock
|
||||
@ -122,6 +122,9 @@ class EventStreamPermissionsTestCase(RestTestCase):
|
||||
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
||||
persistence_service.get_latest_pdus_in_context.return_value = []
|
||||
|
||||
self.mock_config = NonCallableMock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
hs = HomeServer(
|
||||
"test",
|
||||
db_pool=None,
|
||||
@ -139,7 +142,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
|
||||
ratelimiter=NonCallableMock(spec_set=[
|
||||
"send_message",
|
||||
]),
|
||||
config=NonCallableMock(),
|
||||
config=self.mock_config,
|
||||
)
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
|
@ -18,9 +18,9 @@
|
||||
from tests import unittest
|
||||
from twisted.internet import defer
|
||||
|
||||
from mock import Mock
|
||||
from mock import Mock, NonCallableMock
|
||||
|
||||
from ..utils import MockHttpResource
|
||||
from ..utils import MockHttpResource, MockKey
|
||||
|
||||
from synapse.api.errors import SynapseError, AuthError
|
||||
from synapse.server import HomeServer
|
||||
@ -41,6 +41,9 @@ class ProfileTestCase(unittest.TestCase):
|
||||
"set_avatar_url",
|
||||
])
|
||||
|
||||
self.mock_config = NonCallableMock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
hs = HomeServer("test",
|
||||
db_pool=None,
|
||||
http_client=None,
|
||||
@ -48,6 +51,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||
federation=Mock(),
|
||||
replication_layer=Mock(),
|
||||
datastore=None,
|
||||
config=self.mock_config,
|
||||
)
|
||||
|
||||
def _get_user_by_req(request=None):
|
||||
|
@ -27,7 +27,7 @@ from synapse.server import HomeServer
|
||||
import json
|
||||
import urllib
|
||||
|
||||
from ..utils import MockHttpResource, MemoryDataStore
|
||||
from ..utils import MockHttpResource, MemoryDataStore, MockKey
|
||||
from .utils import RestTestCase
|
||||
|
||||
from mock import Mock, NonCallableMock
|
||||
@ -50,6 +50,9 @@ class RoomPermissionsTestCase(RestTestCase):
|
||||
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
||||
persistence_service.get_latest_pdus_in_context.return_value = []
|
||||
|
||||
self.mock_config = NonCallableMock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
hs = HomeServer(
|
||||
"red",
|
||||
db_pool=None,
|
||||
@ -61,7 +64,7 @@ class RoomPermissionsTestCase(RestTestCase):
|
||||
ratelimiter=NonCallableMock(spec_set=[
|
||||
"send_message",
|
||||
]),
|
||||
config=NonCallableMock(),
|
||||
config=self.mock_config,
|
||||
)
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
@ -408,6 +411,9 @@ class RoomsMemberListTestCase(RestTestCase):
|
||||
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
||||
persistence_service.get_latest_pdus_in_context.return_value = []
|
||||
|
||||
self.mock_config = NonCallableMock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
hs = HomeServer(
|
||||
"red",
|
||||
db_pool=None,
|
||||
@ -419,7 +425,7 @@ class RoomsMemberListTestCase(RestTestCase):
|
||||
ratelimiter=NonCallableMock(spec_set=[
|
||||
"send_message",
|
||||
]),
|
||||
config=NonCallableMock(),
|
||||
config=self.mock_config,
|
||||
)
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
@ -497,6 +503,9 @@ class RoomsCreateTestCase(RestTestCase):
|
||||
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
||||
persistence_service.get_latest_pdus_in_context.return_value = []
|
||||
|
||||
self.mock_config = NonCallableMock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
hs = HomeServer(
|
||||
"red",
|
||||
db_pool=None,
|
||||
@ -508,7 +517,7 @@ class RoomsCreateTestCase(RestTestCase):
|
||||
ratelimiter=NonCallableMock(spec_set=[
|
||||
"send_message",
|
||||
]),
|
||||
config=NonCallableMock(),
|
||||
config=self.mock_config,
|
||||
)
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
@ -598,6 +607,9 @@ class RoomTopicTestCase(RestTestCase):
|
||||
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
||||
persistence_service.get_latest_pdus_in_context.return_value = []
|
||||
|
||||
self.mock_config = NonCallableMock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
hs = HomeServer(
|
||||
"red",
|
||||
db_pool=None,
|
||||
@ -609,7 +621,7 @@ class RoomTopicTestCase(RestTestCase):
|
||||
ratelimiter=NonCallableMock(spec_set=[
|
||||
"send_message",
|
||||
]),
|
||||
config=NonCallableMock(),
|
||||
config=self.mock_config,
|
||||
)
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
@ -712,6 +724,9 @@ class RoomMemberStateTestCase(RestTestCase):
|
||||
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
||||
persistence_service.get_latest_pdus_in_context.return_value = []
|
||||
|
||||
self.mock_config = NonCallableMock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
hs = HomeServer(
|
||||
"red",
|
||||
db_pool=None,
|
||||
@ -723,7 +738,7 @@ class RoomMemberStateTestCase(RestTestCase):
|
||||
ratelimiter=NonCallableMock(spec_set=[
|
||||
"send_message",
|
||||
]),
|
||||
config=NonCallableMock(),
|
||||
config=self.mock_config,
|
||||
)
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
@ -853,6 +868,9 @@ class RoomMessagesTestCase(RestTestCase):
|
||||
persistence_service = Mock(spec=["get_latest_pdus_in_context"])
|
||||
persistence_service.get_latest_pdus_in_context.return_value = []
|
||||
|
||||
self.mock_config = NonCallableMock()
|
||||
self.mock_config.signing_key = [MockKey()]
|
||||
|
||||
hs = HomeServer(
|
||||
"red",
|
||||
db_pool=None,
|
||||
@ -864,7 +882,7 @@ class RoomMessagesTestCase(RestTestCase):
|
||||
ratelimiter=NonCallableMock(spec_set=[
|
||||
"send_message",
|
||||
]),
|
||||
config=NonCallableMock(),
|
||||
config=self.mock_config,
|
||||
)
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
|
@ -118,13 +118,14 @@ class MockHttpResource(HttpServer):
|
||||
class MockKey(object):
|
||||
alg = "mock_alg"
|
||||
version = "mock_version"
|
||||
signature = b"\x9a\x87$"
|
||||
|
||||
@property
|
||||
def verify_key(self):
|
||||
return self
|
||||
|
||||
def sign(self, message):
|
||||
return b"\x9a\x87$"
|
||||
return self
|
||||
|
||||
def verify(self, message, sig):
|
||||
assert sig == b"\x9a\x87$"
|
||||
|
Loading…
Reference in New Issue
Block a user