Merge branch 'event_signing' of github.com:matrix-org/synapse into federation_authorization

Conflicts:
	synapse/storage/__init__.py
This commit is contained in:
Erik Johnston 2014-10-27 11:58:32 +00:00
commit ad9226eeec
24 changed files with 580 additions and 77 deletions

View File

@ -1,13 +1,13 @@
Signing JSON Signing JSON
============ ============
JSON is signed by encoding the JSON object without ``signatures`` or ``meta`` JSON is signed by encoding the JSON object without ``signatures`` or ``unsigned``
keys using a canonical encoding. The JSON bytes are then signed using the keys using a canonical encoding. The JSON bytes are then signed using the
signature algorithm and the signature encoded using base64 with the padding signature algorithm and the signature encoded using base64 with the padding
stripped. The resulting base64 signature is added to an object under the stripped. The resulting base64 signature is added to an object under the
*signing key identifier* which is added to the ``signatures`` object under the *signing key identifier* which is added to the ``signatures`` object under the
name of the server signing it which is added back to the original JSON object name of the server signing it which is added back to the original JSON object
along with the ``meta`` object. along with the ``unsigned`` object.
The *signing key identifier* is the concatenation of the *signing algorithm* The *signing key identifier* is the concatenation of the *signing algorithm*
and a *key version*. The *signing algorithm* identifies the algorithm used to and a *key version*. The *signing algorithm* identifies the algorithm used to
@ -15,8 +15,8 @@ sign the JSON. The currently support value for *signing algorithm* is
``ed25519`` as implemented by NACL (http://nacl.cr.yp.to/). The *key version* ``ed25519`` as implemented by NACL (http://nacl.cr.yp.to/). The *key version*
is used to distinguish between different signing keys used by the same entity. is used to distinguish between different signing keys used by the same entity.
The ``meta`` object and the ``signatures`` object are not covered by the The ``unsigned`` object and the ``signatures`` object are not covered by the
signature. Therefore intermediate servers can add metadata such as time stamps signature. Therefore intermediate servers can add unsigneddata such as time stamps
and additional signatures. and additional signatures.
@ -27,7 +27,7 @@ and additional signatures.
"signing_keys": { "signing_keys": {
"ed25519:1": "XSl0kuyvrXNj6A+7/tkrB9sxSbRi08Of5uRhxOqZtEQ" "ed25519:1": "XSl0kuyvrXNj6A+7/tkrB9sxSbRi08Of5uRhxOqZtEQ"
}, },
"meta": { "unsigned": {
"retrieved_ts_ms": 922834800000 "retrieved_ts_ms": 922834800000
}, },
"signatures": { "signatures": {
@ -41,7 +41,7 @@ and additional signatures.
def sign_json(json_object, signing_key, signing_name): def sign_json(json_object, signing_key, signing_name):
signatures = json_object.pop("signatures", {}) signatures = json_object.pop("signatures", {})
meta = json_object.pop("meta", None) unsigned = json_object.pop("unsigned", None)
signed = signing_key.sign(encode_canonical_json(json_object)) signed = signing_key.sign(encode_canonical_json(json_object))
signature_base64 = encode_base64(signed.signature) signature_base64 = encode_base64(signed.signature)
@ -50,8 +50,8 @@ and additional signatures.
signatures.setdefault(sigature_name, {})[key_id] = signature_base64 signatures.setdefault(sigature_name, {})[key_id] = signature_base64
json_object["signatures"] = signatures json_object["signatures"] = signatures
if meta is not None: if unsigned is not None:
json_object["meta"] = meta json_object["unsigned"] = unsigned
return json_object return json_object

69
scripts/hash_history.py Normal file
View File

@ -0,0 +1,69 @@
from synapse.storage.pdu import PduStore
from synapse.storage.signatures import SignatureStore
from synapse.storage._base import SQLBaseStore
from synapse.federation.units import Pdu
from synapse.crypto.event_signing import (
add_event_pdu_content_hash, compute_pdu_event_reference_hash
)
from synapse.api.events.utils import prune_pdu
from syutil.base64util import encode_base64, decode_base64
from syutil.jsonutil import encode_canonical_json
import sqlite3
import sys
class Store(object):
_get_pdu_tuples = PduStore.__dict__["_get_pdu_tuples"]
_get_pdu_content_hashes_txn = SignatureStore.__dict__["_get_pdu_content_hashes_txn"]
_get_prev_pdu_hashes_txn = SignatureStore.__dict__["_get_prev_pdu_hashes_txn"]
_get_pdu_origin_signatures_txn = SignatureStore.__dict__["_get_pdu_origin_signatures_txn"]
_store_pdu_content_hash_txn = SignatureStore.__dict__["_store_pdu_content_hash_txn"]
_store_pdu_reference_hash_txn = SignatureStore.__dict__["_store_pdu_reference_hash_txn"]
_store_prev_pdu_hash_txn = SignatureStore.__dict__["_store_prev_pdu_hash_txn"]
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
store = Store()
def select_pdus(cursor):
cursor.execute(
"SELECT pdu_id, origin FROM pdus ORDER BY depth ASC"
)
ids = cursor.fetchall()
pdu_tuples = store._get_pdu_tuples(cursor, ids)
pdus = [Pdu.from_pdu_tuple(p) for p in pdu_tuples]
reference_hashes = {}
for pdu in pdus:
try:
if pdu.prev_pdus:
print "PROCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus
for pdu_id, origin, hashes in pdu.prev_pdus:
ref_alg, ref_hsh = reference_hashes[(pdu_id, origin)]
hashes[ref_alg] = encode_base64(ref_hsh)
store._store_prev_pdu_hash_txn(cursor, pdu.pdu_id, pdu.origin, pdu_id, origin, ref_alg, ref_hsh)
print "SUCCESS", pdu.pdu_id, pdu.origin, pdu.prev_pdus
pdu = add_event_pdu_content_hash(pdu)
ref_alg, ref_hsh = compute_pdu_event_reference_hash(pdu)
reference_hashes[(pdu.pdu_id, pdu.origin)] = (ref_alg, ref_hsh)
store._store_pdu_reference_hash_txn(cursor, pdu.pdu_id, pdu.origin, ref_alg, ref_hsh)
for alg, hsh_base64 in pdu.hashes.items():
print alg, hsh_base64
store._store_pdu_content_hash_txn(cursor, pdu.pdu_id, pdu.origin, alg, decode_base64(hsh_base64))
except:
print "FAILED_", pdu.pdu_id, pdu.origin, pdu.prev_pdus
def main():
conn = sqlite3.connect(sys.argv[1])
cursor = conn.cursor()
select_pdus(cursor)
conn.commit()
if __name__=='__main__':
main()

View File

@ -65,13 +65,13 @@ class SynapseEvent(JsonEncodedObject):
internal_keys = [ internal_keys = [
"is_state", "is_state",
"prev_events",
"depth", "depth",
"destinations", "destinations",
"origin", "origin",
"outlier", "outlier",
"power_level", "power_level",
"redacted", "redacted",
"prev_pdus",
] ]
required_keys = [ required_keys = [

View File

@ -27,7 +27,14 @@ def prune_event(event):
the user has specified, but we do want to keep necessary information like the user has specified, but we do want to keep necessary information like
type, state_key etc. type, state_key etc.
""" """
return _prune_event_or_pdu(event.type, event)
def prune_pdu(pdu):
"""Removes keys that contain unrestricted and non-essential data from a PDU
"""
return _prune_event_or_pdu(pdu.pdu_type, pdu)
def _prune_event_or_pdu(event_type, event):
# Remove all extraneous fields. # Remove all extraneous fields.
event.unrecognized_keys = {} event.unrecognized_keys = {}
@ -38,25 +45,25 @@ def prune_event(event):
if field in event.content: if field in event.content:
new_content[field] = event.content[field] new_content[field] = event.content[field]
if event.type == RoomMemberEvent.TYPE: if event_type == RoomMemberEvent.TYPE:
add_fields("membership") add_fields("membership")
elif event.type == RoomCreateEvent.TYPE: elif event_type == RoomCreateEvent.TYPE:
add_fields("creator") add_fields("creator")
elif event.type == RoomJoinRulesEvent.TYPE: elif event_type == RoomJoinRulesEvent.TYPE:
add_fields("join_rule") add_fields("join_rule")
elif event.type == RoomPowerLevelsEvent.TYPE: elif event_type == RoomPowerLevelsEvent.TYPE:
# TODO: Actually check these are valid user_ids etc. # TODO: Actually check these are valid user_ids etc.
add_fields("default") add_fields("default")
for k, v in event.content.items(): for k, v in event.content.items():
if k.startswith("@") and isinstance(v, (int, long)): if k.startswith("@") and isinstance(v, (int, long)):
new_content[k] = v new_content[k] = v
elif event.type == RoomAddStateLevelEvent.TYPE: elif event_type == RoomAddStateLevelEvent.TYPE:
add_fields("level") add_fields("level")
elif event.type == RoomSendEventLevelEvent.TYPE: elif event_type == RoomSendEventLevelEvent.TYPE:
add_fields("level") add_fields("level")
elif event.type == RoomOpsPowerLevelsEvent.TYPE: elif event_type == RoomOpsPowerLevelsEvent.TYPE:
add_fields("kick_level", "ban_level", "redact_level") add_fields("kick_level", "ban_level", "redact_level")
elif event.type == RoomAliasesEvent.TYPE: elif event_type == RoomAliasesEvent.TYPE:
add_fields("aliases") add_fields("aliases")
event.content = new_content event.content = new_content

View File

@ -74,7 +74,7 @@ class ServerConfig(Config):
return syutil.crypto.signing_key.read_signing_keys( return syutil.crypto.signing_key.read_signing_keys(
signing_keys.splitlines(True) signing_keys.splitlines(True)
) )
except Exception as e: except Exception:
raise ConfigError( raise ConfigError(
"Error reading signing_key." "Error reading signing_key."
" Try running again with --generate-config" " Try running again with --generate-config"

View File

@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.federation.units import Pdu
from synapse.api.events.utils import prune_pdu
from syutil.jsonutil import encode_canonical_json
from syutil.base64util import encode_base64, decode_base64
from syutil.crypto.jsonsign import sign_json, verify_signed_json
import hashlib
import logging
logger = logging.getLogger(__name__)
def add_event_pdu_content_hash(pdu, hash_algorithm=hashlib.sha256):
hashed = _compute_content_hash(pdu, hash_algorithm)
pdu.hashes[hashed.name] = encode_base64(hashed.digest())
return pdu
def check_event_pdu_content_hash(pdu, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents"""
computed_hash = _compute_content_hash(pdu, hash_algorithm)
if computed_hash.name not in pdu.hashes:
raise Exception("Algorithm %s not in hashes %s" % (
computed_hash.name, list(pdu.hashes)
))
message_hash_base64 = pdu.hashes[computed_hash.name]
try:
message_hash_bytes = decode_base64(message_hash_base64)
except:
raise Exception("Invalid base64: %s" % (message_hash_base64,))
return message_hash_bytes == computed_hash.digest()
def _compute_content_hash(pdu, hash_algorithm):
pdu_json = pdu.get_dict()
#TODO: Make "age_ts" key internal
pdu_json.pop("age_ts", None)
pdu_json.pop("unsigned", None)
pdu_json.pop("signatures", None)
pdu_json.pop("hashes", None)
pdu_json_bytes = encode_canonical_json(pdu_json)
return hash_algorithm(pdu_json_bytes)
def compute_pdu_event_reference_hash(pdu, hash_algorithm=hashlib.sha256):
tmp_pdu = Pdu(**pdu.get_dict())
tmp_pdu = prune_pdu(tmp_pdu)
pdu_json = tmp_pdu.get_dict()
pdu_json.pop("signatures", None)
pdu_json_bytes = encode_canonical_json(pdu_json)
hashed = hash_algorithm(pdu_json_bytes)
return (hashed.name, hashed.digest())
def sign_event_pdu(pdu, signature_name, signing_key):
tmp_pdu = Pdu(**pdu.get_dict())
tmp_pdu = prune_pdu(tmp_pdu)
pdu_json = tmp_pdu.get_dict()
pdu_json = sign_json(pdu_json, signature_name, signing_key)
pdu.signatures = pdu_json["signatures"]
return pdu
def verify_signed_event_pdu(pdu, signature_name, verify_key):
tmp_pdu = Pdu(**pdu.get_dict())
tmp_pdu = prune_pdu(tmp_pdu)
pdu_json = tmp_pdu.get_dict()
verify_signed_json(pdu_json, signature_name, verify_key)

View File

@ -17,7 +17,6 @@
from twisted.web.http import HTTPClient from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.endpoints import connectProtocol
from synapse.http.endpoint import matrix_endpoint from synapse.http.endpoint import matrix_endpoint
import json import json
import logging import logging

View File

@ -14,6 +14,9 @@
# limitations under the License. # limitations under the License.
from .units import Pdu from .units import Pdu
from synapse.crypto.event_signing import (
add_event_pdu_content_hash, sign_event_pdu
)
import copy import copy
@ -33,6 +36,7 @@ def encode_event_id(pdu_id, origin):
class PduCodec(object): class PduCodec(object):
def __init__(self, hs): def __init__(self, hs):
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname self.server_name = hs.hostname
self.event_factory = hs.get_event_factory() self.event_factory = hs.get_event_factory()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -43,9 +47,7 @@ class PduCodec(object):
kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin) kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
kwargs["room_id"] = pdu.context kwargs["room_id"] = pdu.context
kwargs["etype"] = pdu.pdu_type kwargs["etype"] = pdu.pdu_type
kwargs["prev_events"] = [ kwargs["prev_pdus"] = pdu.prev_pdus
encode_event_id(p[0], p[1]) for p in pdu.prev_pdus
]
if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"): if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
kwargs["prev_state"] = encode_event_id( kwargs["prev_state"] = encode_event_id(
@ -76,11 +78,8 @@ class PduCodec(object):
d["context"] = event.room_id d["context"] = event.room_id
d["pdu_type"] = event.type d["pdu_type"] = event.type
if hasattr(event, "prev_events"): if hasattr(event, "prev_pdus"):
d["prev_pdus"] = [ d["prev_pdus"] = event.prev_pdus
decode_event_id(e, self.server_name)
for e in event.prev_events
]
if hasattr(event, "prev_state"): if hasattr(event, "prev_state"):
d["prev_state_id"], d["prev_state_origin"] = ( d["prev_state_id"], d["prev_state_origin"] = (
@ -93,10 +92,12 @@ class PduCodec(object):
kwargs = copy.deepcopy(event.unrecognized_keys) kwargs = copy.deepcopy(event.unrecognized_keys)
kwargs.update({ kwargs.update({
k: v for k, v in d.items() k: v for k, v in d.items()
if k not in ["event_id", "room_id", "type", "prev_events"] if k not in ["event_id", "room_id", "type"]
}) })
if "origin_server_ts" not in kwargs: if "origin_server_ts" not in kwargs:
kwargs["origin_server_ts"] = int(self.clock.time_msec()) kwargs["origin_server_ts"] = int(self.clock.time_msec())
return Pdu(**kwargs) pdu = Pdu(**kwargs)
pdu = add_event_pdu_content_hash(pdu)
return sign_event_pdu(pdu, self.server_name, self.signing_key)

View File

@ -297,6 +297,10 @@ class ReplicationLayer(object):
transaction = Transaction(**transaction_data) transaction = Transaction(**transaction_data)
for p in transaction.pdus: for p in transaction.pdus:
if "unsigned" in p:
unsigned = p["unsigned"]
if "age" in unsigned:
p["age"] = unsigned["age"]
if "age" in p: if "age" in p:
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"]) p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
del p["age"] del p["age"]
@ -467,14 +471,16 @@ class ReplicationLayer(object):
transmission. transmission.
""" """
pdus = [p.get_dict() for p in pdu_list] pdus = [p.get_dict() for p in pdu_list]
time_now = self._clock.time_msec()
for p in pdus: for p in pdus:
if "age_ts" in pdus: if "age_ts" in p:
p["age"] = int(self.clock.time_msec()) - p["age_ts"] age = time_now - p["age_ts"]
p.setdefault("unsigned", {})["age"] = int(age)
del p["age_ts"]
return Transaction( return Transaction(
origin=self.server_name, origin=self.server_name,
pdus=pdus, pdus=pdus,
origin_server_ts=int(self._clock.time_msec()), origin_server_ts=int(time_now),
destination=None, destination=None,
) )
@ -498,7 +504,7 @@ class ReplicationLayer(object):
min_depth = yield self.store.get_min_depth_for_context(pdu.context) min_depth = yield self.store.get_min_depth_for_context(pdu.context)
if min_depth and pdu.depth > min_depth: if min_depth and pdu.depth > min_depth:
for pdu_id, origin in pdu.prev_pdus: for pdu_id, origin, hashes in pdu.prev_pdus:
exists = yield self._get_persisted_pdu(pdu_id, origin) exists = yield self._get_persisted_pdu(pdu_id, origin)
if not exists: if not exists:
@ -654,7 +660,7 @@ class _TransactionQueue(object):
logger.debug("TX [%s] Persisting transaction...", destination) logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new( transaction = Transaction.create_new(
origin_server_ts=self._clock.time_msec(), origin_server_ts=int(self._clock.time_msec()),
transaction_id=str(self._next_txn_id), transaction_id=str(self._next_txn_id),
origin=self.server_name, origin=self.server_name,
destination=destination, destination=destination,
@ -679,7 +685,9 @@ class _TransactionQueue(object):
if "pdus" in data: if "pdus" in data:
for p in data["pdus"]: for p in data["pdus"]:
if "age_ts" in p: if "age_ts" in p:
p["age"] = now - int(p["age_ts"]) unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data return data
code, response = yield self.transport_layer.send_transaction( code, response = yield self.transport_layer.send_transaction(

View File

@ -18,6 +18,7 @@ server protocol.
""" """
from synapse.util.jsonobject import JsonEncodedObject from synapse.util.jsonobject import JsonEncodedObject
from syutil.base64util import encode_base64
import logging import logging
import json import json
@ -63,9 +64,10 @@ class Pdu(JsonEncodedObject):
"depth", "depth",
"content", "content",
"outlier", "outlier",
"hashes",
"signatures",
"is_state", # Below this are keys valid only for State Pdus. "is_state", # Below this are keys valid only for State Pdus.
"state_key", "state_key",
"power_level",
"prev_state_id", "prev_state_id",
"prev_state_origin", "prev_state_origin",
"required_power_level", "required_power_level",
@ -91,7 +93,7 @@ class Pdu(JsonEncodedObject):
# just leaving it as a dict. (OR DO WE?!) # just leaving it as a dict. (OR DO WE?!)
def __init__(self, destinations=[], is_state=False, prev_pdus=[], def __init__(self, destinations=[], is_state=False, prev_pdus=[],
outlier=False, **kwargs): outlier=False, hashes={}, signatures={}, **kwargs):
if is_state: if is_state:
for required_key in ["state_key"]: for required_key in ["state_key"]:
if required_key not in kwargs: if required_key not in kwargs:
@ -99,9 +101,11 @@ class Pdu(JsonEncodedObject):
super(Pdu, self).__init__( super(Pdu, self).__init__(
destinations=destinations, destinations=destinations,
is_state=is_state, is_state=bool(is_state),
prev_pdus=prev_pdus, prev_pdus=prev_pdus,
outlier=outlier, outlier=outlier,
hashes=hashes,
signatures=signatures,
**kwargs **kwargs
) )
@ -120,6 +124,10 @@ class Pdu(JsonEncodedObject):
d = copy.copy(pdu_tuple.pdu_entry._asdict()) d = copy.copy(pdu_tuple.pdu_entry._asdict())
d["origin_server_ts"] = d.pop("ts") d["origin_server_ts"] = d.pop("ts")
for k in d.keys():
if d[k] is None:
del d[k]
d["content"] = json.loads(d["content_json"]) d["content"] = json.loads(d["content_json"])
del d["content_json"] del d["content_json"]
@ -127,8 +135,28 @@ class Pdu(JsonEncodedObject):
if "unrecognized_keys" in d and d["unrecognized_keys"]: if "unrecognized_keys" in d and d["unrecognized_keys"]:
args.update(json.loads(d["unrecognized_keys"])) args.update(json.loads(d["unrecognized_keys"]))
hashes = {
alg: encode_base64(hsh)
for alg, hsh in pdu_tuple.hashes.items()
}
signatures = {
kid: encode_base64(sig)
for kid, sig in pdu_tuple.signatures.items()
}
prev_pdus = []
for prev_pdu in pdu_tuple.prev_pdu_list:
prev_hashes = pdu_tuple.edge_hashes.get(prev_pdu, {})
prev_hashes = {
alg: encode_base64(hsh) for alg, hsh in prev_hashes.items()
}
prev_pdus.append((prev_pdu[0], prev_pdu[1], prev_hashes))
return Pdu( return Pdu(
prev_pdus=pdu_tuple.prev_pdu_list, prev_pdus=prev_pdus,
hashes=hashes,
signatures=signatures,
**args **args
) )
else: else:

View File

@ -344,7 +344,7 @@ class RoomInitialSyncRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user = yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
# TODO: Get all the initial sync data for this room and return in the # TODO: Get all the initial sync data for this room and return in the
# same format as initial sync, that is: # same format as initial sync, that is:
# { # {

View File

@ -75,10 +75,6 @@ class StateHandler(object):
snapshot.fill_out_prev_events(event) snapshot.fill_out_prev_events(event)
yield self.annotate_state_groups(event) yield self.annotate_state_groups(event)
event.prev_events = [
e for e in event.prev_events if e != event.event_id
]
current_state = snapshot.prev_state_pdu current_state = snapshot.prev_state_pdu
if current_state: if current_state:

View File

@ -40,7 +40,14 @@ from .stream import StreamStore
from .pdu import StatePduStore, PduStore, PdusTable from .pdu import StatePduStore, PduStore, PdusTable
from .transactions import TransactionStore from .transactions import TransactionStore
from .keys import KeyStore from .keys import KeyStore
from .state import StateStore from .state import StateStore
from .signatures import SignatureStore
from syutil.base64util import decode_base64
from synapse.crypto.event_signing import compute_pdu_event_reference_hash
import json import json
import logging import logging
@ -61,6 +68,7 @@ SCHEMAS = [
"keys", "keys",
"redactions", "redactions",
"state", "state",
"signatures",
] ]
@ -78,7 +86,7 @@ class _RollbackButIsFineException(Exception):
class DataStore(RoomMemberStore, RoomStore, class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, FeedbackStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
PresenceStore, PduStore, StatePduStore, TransactionStore, PresenceStore, PduStore, StatePduStore, TransactionStore,
DirectoryStore, KeyStore, StateStore): DirectoryStore, KeyStore, StateStore, SignatureStore):
def __init__(self, hs): def __init__(self, hs):
super(DataStore, self).__init__(hs) super(DataStore, self).__init__(hs)
@ -146,6 +154,8 @@ class DataStore(RoomMemberStore, RoomStore,
def _persist_event_pdu_txn(self, txn, pdu): def _persist_event_pdu_txn(self, txn, pdu):
cols = dict(pdu.__dict__) cols = dict(pdu.__dict__)
unrec_keys = dict(pdu.unrecognized_keys) unrec_keys = dict(pdu.unrecognized_keys)
del cols["hashes"]
del cols["signatures"]
del cols["content"] del cols["content"]
del cols["prev_pdus"] del cols["prev_pdus"]
cols["content_json"] = json.dumps(pdu.content) cols["content_json"] = json.dumps(pdu.content)
@ -161,6 +171,33 @@ class DataStore(RoomMemberStore, RoomStore,
logger.debug("Persisting: %s", repr(cols)) logger.debug("Persisting: %s", repr(cols))
for hash_alg, hash_base64 in pdu.hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_pdu_content_hash_txn(
txn, pdu.pdu_id, pdu.origin, hash_alg, hash_bytes,
)
signatures = pdu.signatures.get(pdu.origin, {})
for key_id, signature_base64 in signatures.items():
signature_bytes = decode_base64(signature_base64)
self._store_pdu_origin_signature_txn(
txn, pdu.pdu_id, pdu.origin, key_id, signature_bytes,
)
for prev_pdu_id, prev_origin, prev_hashes in pdu.prev_pdus:
for alg, hash_base64 in prev_hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_prev_pdu_hash_txn(
txn, pdu.pdu_id, pdu.origin, prev_pdu_id, prev_origin, alg,
hash_bytes
)
(ref_alg, ref_hash_bytes) = compute_pdu_event_reference_hash(pdu)
self._store_pdu_reference_hash_txn(
txn, pdu.pdu_id, pdu.origin, ref_alg, ref_hash_bytes
)
if pdu.is_state: if pdu.is_state:
self._persist_state_txn(txn, pdu.prev_pdus, cols) self._persist_state_txn(txn, pdu.prev_pdus, cols)
else: else:
@ -338,6 +375,7 @@ class DataStore(RoomMemberStore, RoomStore,
prev_pdus = self._get_latest_pdus_in_context( prev_pdus = self._get_latest_pdus_in_context(
txn, room_id txn, room_id
) )
if state_type is not None and state_key is not None: if state_type is not None and state_key is not None:
prev_state_pdu = self._get_current_state_pdu( prev_state_pdu = self._get_current_state_pdu(
txn, room_id, state_type, state_key txn, room_id, state_type, state_key
@ -387,17 +425,16 @@ class Snapshot(object):
self.prev_state_pdu = prev_state_pdu self.prev_state_pdu = prev_state_pdu
def fill_out_prev_events(self, event): def fill_out_prev_events(self, event):
if hasattr(event, "prev_events"): if hasattr(event, "prev_pdus"):
return return
es = [ event.prev_pdus = [
"%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus (p_id, origin, hashes)
for p_id, origin, hashes, _ in self.prev_pdus
] ]
event.prev_events = [e for e in es if e != event.event_id]
if self.prev_pdus: if self.prev_pdus:
event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1 event.depth = max([int(v) for _, _, _, v in self.prev_pdus]) + 1
else: else:
event.depth = 0 event.depth = 0

View File

@ -104,7 +104,6 @@ class KeyStore(SQLBaseStore):
ts_now_ms (int): The time now in milliseconds ts_now_ms (int): The time now in milliseconds
verification_key (VerifyKey): The NACL verify key. verification_key (VerifyKey): The NACL verify key.
""" """
verify_key_bytes = verify_key.encode()
return self._simple_insert( return self._simple_insert(
table="server_signature_keys", table="server_signature_keys",
values={ values={

View File

@ -20,10 +20,13 @@ from ._base import SQLBaseStore, Table, JoinHelper
from synapse.federation.units import Pdu from synapse.federation.units import Pdu
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from syutil.base64util import encode_base64
from collections import namedtuple from collections import namedtuple
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -64,6 +67,13 @@ class PduStore(SQLBaseStore):
for r in PduEdgesTable.decode_results(txn.fetchall()) for r in PduEdgesTable.decode_results(txn.fetchall())
] ]
edge_hashes = self._get_prev_pdu_hashes_txn(txn, pdu_id, origin)
hashes = self._get_pdu_content_hashes_txn(txn, pdu_id, origin)
signatures = self._get_pdu_origin_signatures_txn(
txn, pdu_id, origin
)
query = ( query = (
"SELECT %(fields)s FROM %(pdus)s as p " "SELECT %(fields)s FROM %(pdus)s as p "
"LEFT JOIN %(state)s as s " "LEFT JOIN %(state)s as s "
@ -80,7 +90,9 @@ class PduStore(SQLBaseStore):
row = txn.fetchone() row = txn.fetchone()
if row: if row:
results.append(PduTuple(PduEntry(*row), edges)) results.append(PduTuple(
PduEntry(*row), edges, hashes, signatures, edge_hashes
))
return results return results
@ -309,9 +321,14 @@ class PduStore(SQLBaseStore):
(context, ) (context, )
) )
results = txn.fetchall() results = []
for pdu_id, origin, depth in txn.fetchall():
hashes = self._get_pdu_reference_hashes_txn(txn, pdu_id, origin)
sha256_bytes = hashes["sha256"]
prev_hashes = {"sha256": encode_base64(sha256_bytes)}
results.append((pdu_id, origin, prev_hashes, depth))
return [(row[0], row[1], row[2]) for row in results] return results
@defer.inlineCallbacks @defer.inlineCallbacks
def get_oldest_pdus_in_context(self, context): def get_oldest_pdus_in_context(self, context):
@ -430,7 +447,7 @@ class PduStore(SQLBaseStore):
"DELETE FROM %s WHERE pdu_id = ? AND origin = ?" "DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
% PduForwardExtremitiesTable.table_name % PduForwardExtremitiesTable.table_name
) )
txn.executemany(query, prev_pdus) txn.executemany(query, list(p[:2] for p in prev_pdus))
# We only insert as a forward extremety the new pdu if there are no # We only insert as a forward extremety the new pdu if there are no
# other pdus that reference it as a prev pdu # other pdus that reference it as a prev pdu
@ -453,7 +470,7 @@ class PduStore(SQLBaseStore):
# deleted in a second if they're incorrect anyway. # deleted in a second if they're incorrect anyway.
txn.executemany( txn.executemany(
PduBackwardExtremitiesTable.insert_statement(), PduBackwardExtremitiesTable.insert_statement(),
[(i, o, context) for i, o in prev_pdus] [(i, o, context) for i, o, _ in prev_pdus]
) )
# Also delete from the backwards extremities table all ones that # Also delete from the backwards extremities table all ones that
@ -914,7 +931,7 @@ This does not include a prev_pdus key.
PduTuple = namedtuple( PduTuple = namedtuple(
"PduTuple", "PduTuple",
("pdu_entry", "prev_pdu_list") ("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes")
) )
""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent """ This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
the `prev_pdus` key of a PDU. the `prev_pdus` key of a PDU.

View File

@ -0,0 +1,66 @@
/* Copyright 2014 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS pdu_content_hashes (
pdu_id TEXT,
origin TEXT,
algorithm TEXT,
hash BLOB,
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, algorithm)
);
CREATE INDEX IF NOT EXISTS pdu_content_hashes_id ON pdu_content_hashes (
pdu_id, origin
);
CREATE TABLE IF NOT EXISTS pdu_reference_hashes (
pdu_id TEXT,
origin TEXT,
algorithm TEXT,
hash BLOB,
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, algorithm)
);
CREATE INDEX IF NOT EXISTS pdu_reference_hashes_id ON pdu_reference_hashes (
pdu_id, origin
);
CREATE TABLE IF NOT EXISTS pdu_origin_signatures (
pdu_id TEXT,
origin TEXT,
key_id TEXT,
signature BLOB,
CONSTRAINT uniqueness UNIQUE (pdu_id, origin, key_id)
);
CREATE INDEX IF NOT EXISTS pdu_origin_signatures_id ON pdu_origin_signatures (
pdu_id, origin
);
CREATE TABLE IF NOT EXISTS pdu_edge_hashes(
pdu_id TEXT,
origin TEXT,
prev_pdu_id TEXT,
prev_origin TEXT,
algorithm TEXT,
hash BLOB,
CONSTRAINT uniqueness UNIQUE (
pdu_id, origin, prev_pdu_id, prev_origin, algorithm
)
);
CREATE INDEX IF NOT EXISTS pdu_edge_hashes_id ON pdu_edge_hashes(
pdu_id, origin
);

View File

@ -0,0 +1,155 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from _base import SQLBaseStore
class SignatureStore(SQLBaseStore):
"""Persistence for PDU signatures and hashes"""
def _get_pdu_content_hashes_txn(self, txn, pdu_id, origin):
"""Get all the hashes for a given PDU.
Args:
txn (cursor):
pdu_id (str): Id for the PDU.
origin (str): origin of the PDU.
Returns:
A dict of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
" FROM pdu_content_hashes"
" WHERE pdu_id = ? and origin = ?"
)
txn.execute(query, (pdu_id, origin))
return dict(txn.fetchall())
def _store_pdu_content_hash_txn(self, txn, pdu_id, origin, algorithm,
hash_bytes):
"""Store a hash for a PDU
Args:
txn (cursor):
pdu_id (str): Id for the PDU.
origin (str): origin of the PDU.
algorithm (str): Hashing algorithm.
hash_bytes (bytes): Hash function output bytes.
"""
self._simple_insert_txn(txn, "pdu_content_hashes", {
"pdu_id": pdu_id,
"origin": origin,
"algorithm": algorithm,
"hash": buffer(hash_bytes),
})
def _get_pdu_reference_hashes_txn(self, txn, pdu_id, origin):
"""Get all the hashes for a given PDU.
Args:
txn (cursor):
pdu_id (str): Id for the PDU.
origin (str): origin of the PDU.
Returns:
A dict of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
" FROM pdu_reference_hashes"
" WHERE pdu_id = ? and origin = ?"
)
txn.execute(query, (pdu_id, origin))
return dict(txn.fetchall())
def _store_pdu_reference_hash_txn(self, txn, pdu_id, origin, algorithm,
hash_bytes):
"""Store a hash for a PDU
Args:
txn (cursor):
pdu_id (str): Id for the PDU.
origin (str): origin of the PDU.
algorithm (str): Hashing algorithm.
hash_bytes (bytes): Hash function output bytes.
"""
self._simple_insert_txn(txn, "pdu_reference_hashes", {
"pdu_id": pdu_id,
"origin": origin,
"algorithm": algorithm,
"hash": buffer(hash_bytes),
})
def _get_pdu_origin_signatures_txn(self, txn, pdu_id, origin):
"""Get all the signatures for a given PDU.
Args:
txn (cursor):
pdu_id (str): Id for the PDU.
origin (str): origin of the PDU.
Returns:
A dict of key_id -> signature_bytes.
"""
query = (
"SELECT key_id, signature"
" FROM pdu_origin_signatures"
" WHERE pdu_id = ? and origin = ?"
)
txn.execute(query, (pdu_id, origin))
return dict(txn.fetchall())
def _store_pdu_origin_signature_txn(self, txn, pdu_id, origin, key_id,
signature_bytes):
"""Store a signature from the origin server for a PDU.
Args:
txn (cursor):
pdu_id (str): Id for the PDU.
origin (str): origin of the PDU.
key_id (str): Id for the signing key.
signature (bytes): The signature.
"""
self._simple_insert_txn(txn, "pdu_origin_signatures", {
"pdu_id": pdu_id,
"origin": origin,
"key_id": key_id,
"signature": buffer(signature_bytes),
})
def _get_prev_pdu_hashes_txn(self, txn, pdu_id, origin):
"""Get all the hashes for previous PDUs of a PDU
Args:
txn (cursor):
pdu_id (str): Id of the PDU.
origin (str): Origin of the PDU.
Returns:
dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes.
"""
query = (
"SELECT prev_pdu_id, prev_origin, algorithm, hash"
" FROM pdu_edge_hashes"
" WHERE pdu_id = ? and origin = ?"
)
txn.execute(query, (pdu_id, origin))
results = {}
for prev_pdu_id, prev_origin, algorithm, hash_bytes in txn.fetchall():
hashes = results.setdefault((prev_pdu_id, prev_origin), {})
hashes[algorithm] = hash_bytes
return results
def _store_prev_pdu_hash_txn(self, txn, pdu_id, origin, prev_pdu_id,
prev_origin, algorithm, hash_bytes):
self._simple_insert_txn(txn, "pdu_edge_hashes", {
"pdu_id": pdu_id,
"origin": origin,
"prev_pdu_id": prev_pdu_id,
"prev_origin": prev_origin,
"algorithm": algorithm,
"hash": buffer(hash_bytes),
})

1
synapse/test_pyflakes.py Normal file
View File

@ -0,0 +1 @@
import an_unused_module

View File

@ -41,7 +41,7 @@ def make_pdu(prev_pdus=[], **kwargs):
} }
pdu_fields.update(kwargs) pdu_fields.update(kwargs)
return PduTuple(PduEntry(**pdu_fields), prev_pdus) return PduTuple(PduEntry(**pdu_fields), prev_pdus, {}, {}, {})
class FederationTestCase(unittest.TestCase): class FederationTestCase(unittest.TestCase):
@ -183,6 +183,8 @@ class FederationTestCase(unittest.TestCase):
"is_state": False, "is_state": False,
"content": {"testing": "content here"}, "content": {"testing": "content here"},
"depth": 1, "depth": 1,
"hashes": {},
"signatures": {},
}, },
] ]
}, },

View File

@ -23,14 +23,21 @@ from synapse.federation.units import Pdu
from synapse.server import HomeServer from synapse.server import HomeServer
from mock import Mock from mock import Mock, NonCallableMock
from ..utils import MockKey
class PduCodecTestCase(unittest.TestCase): class PduCodecTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.hs = HomeServer("blargle.net") self.mock_config = NonCallableMock()
self.event_factory = self.hs.get_event_factory() self.mock_config.signing_key = [MockKey()]
self.hs = HomeServer(
"blargle.net",
config=self.mock_config,
)
self.event_factory = self.hs.get_event_factory()
self.codec = PduCodec(self.hs) self.codec = PduCodec(self.hs)
def test_decode_event_id(self): def test_decode_event_id(self):
@ -81,7 +88,7 @@ class PduCodecTestCase(unittest.TestCase):
self.assertEquals(pdu.context, event.room_id) self.assertEquals(pdu.context, event.room_id)
self.assertEquals(pdu.is_state, event.is_state) self.assertEquals(pdu.is_state, event.is_state)
self.assertEquals(pdu.depth, event.depth) self.assertEquals(pdu.depth, event.depth)
self.assertEquals(["alice@bob.com"], event.prev_events) self.assertEquals(pdu.prev_pdus, event.prev_pdus)
self.assertEquals(pdu.content, event.content) self.assertEquals(pdu.content, event.content)
def test_pdu_from_event(self): def test_pdu_from_event(self):
@ -137,7 +144,7 @@ class PduCodecTestCase(unittest.TestCase):
self.assertEquals(pdu.context, event.room_id) self.assertEquals(pdu.context, event.room_id)
self.assertEquals(pdu.is_state, event.is_state) self.assertEquals(pdu.is_state, event.is_state)
self.assertEquals(pdu.depth, event.depth) self.assertEquals(pdu.depth, event.depth)
self.assertEquals(["alice@bob.com"], event.prev_events) self.assertEquals(pdu.prev_pdus, event.prev_pdus)
self.assertEquals(pdu.content, event.content) self.assertEquals(pdu.content, event.content)
self.assertEquals(pdu.state_key, event.state_key) self.assertEquals(pdu.state_key, event.state_key)

View File

@ -28,7 +28,7 @@ from synapse.server import HomeServer
# python imports # python imports
import json import json
from ..utils import MockHttpResource, MemoryDataStore from ..utils import MockHttpResource, MemoryDataStore, MockKey
from .utils import RestTestCase from .utils import RestTestCase
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock
@ -122,6 +122,9 @@ class EventStreamPermissionsTestCase(RestTestCase):
persistence_service = Mock(spec=["get_latest_pdus_in_context"]) persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = [] persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer( hs = HomeServer(
"test", "test",
db_pool=None, db_pool=None,
@ -139,7 +142,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),
config=NonCallableMock(), config=self.mock_config,
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)

View File

@ -18,9 +18,9 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from mock import Mock from mock import Mock, NonCallableMock
from ..utils import MockHttpResource from ..utils import MockHttpResource, MockKey
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.server import HomeServer from synapse.server import HomeServer
@ -41,6 +41,9 @@ class ProfileTestCase(unittest.TestCase):
"set_avatar_url", "set_avatar_url",
]) ])
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer("test", hs = HomeServer("test",
db_pool=None, db_pool=None,
http_client=None, http_client=None,
@ -48,6 +51,7 @@ class ProfileTestCase(unittest.TestCase):
federation=Mock(), federation=Mock(),
replication_layer=Mock(), replication_layer=Mock(),
datastore=None, datastore=None,
config=self.mock_config,
) )
def _get_user_by_req(request=None): def _get_user_by_req(request=None):

View File

@ -27,7 +27,7 @@ from synapse.server import HomeServer
import json import json
import urllib import urllib
from ..utils import MockHttpResource, MemoryDataStore from ..utils import MockHttpResource, MemoryDataStore, MockKey
from .utils import RestTestCase from .utils import RestTestCase
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock
@ -50,6 +50,9 @@ class RoomPermissionsTestCase(RestTestCase):
persistence_service = Mock(spec=["get_latest_pdus_in_context"]) persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = [] persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer( hs = HomeServer(
"red", "red",
db_pool=None, db_pool=None,
@ -61,7 +64,7 @@ class RoomPermissionsTestCase(RestTestCase):
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),
config=NonCallableMock(), config=self.mock_config,
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
@ -408,6 +411,9 @@ class RoomsMemberListTestCase(RestTestCase):
persistence_service = Mock(spec=["get_latest_pdus_in_context"]) persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = [] persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer( hs = HomeServer(
"red", "red",
db_pool=None, db_pool=None,
@ -419,7 +425,7 @@ class RoomsMemberListTestCase(RestTestCase):
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),
config=NonCallableMock(), config=self.mock_config,
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
@ -497,6 +503,9 @@ class RoomsCreateTestCase(RestTestCase):
persistence_service = Mock(spec=["get_latest_pdus_in_context"]) persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = [] persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer( hs = HomeServer(
"red", "red",
db_pool=None, db_pool=None,
@ -508,7 +517,7 @@ class RoomsCreateTestCase(RestTestCase):
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),
config=NonCallableMock(), config=self.mock_config,
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
@ -598,6 +607,9 @@ class RoomTopicTestCase(RestTestCase):
persistence_service = Mock(spec=["get_latest_pdus_in_context"]) persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = [] persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer( hs = HomeServer(
"red", "red",
db_pool=None, db_pool=None,
@ -609,7 +621,7 @@ class RoomTopicTestCase(RestTestCase):
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),
config=NonCallableMock(), config=self.mock_config,
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
@ -712,6 +724,9 @@ class RoomMemberStateTestCase(RestTestCase):
persistence_service = Mock(spec=["get_latest_pdus_in_context"]) persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = [] persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer( hs = HomeServer(
"red", "red",
db_pool=None, db_pool=None,
@ -723,7 +738,7 @@ class RoomMemberStateTestCase(RestTestCase):
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),
config=NonCallableMock(), config=self.mock_config,
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
@ -853,6 +868,9 @@ class RoomMessagesTestCase(RestTestCase):
persistence_service = Mock(spec=["get_latest_pdus_in_context"]) persistence_service = Mock(spec=["get_latest_pdus_in_context"])
persistence_service.get_latest_pdus_in_context.return_value = [] persistence_service.get_latest_pdus_in_context.return_value = []
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
hs = HomeServer( hs = HomeServer(
"red", "red",
db_pool=None, db_pool=None,
@ -864,7 +882,7 @@ class RoomMessagesTestCase(RestTestCase):
ratelimiter=NonCallableMock(spec_set=[ ratelimiter=NonCallableMock(spec_set=[
"send_message", "send_message",
]), ]),
config=NonCallableMock(), config=self.mock_config,
) )
self.ratelimiter = hs.get_ratelimiter() self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)

View File

@ -118,13 +118,14 @@ class MockHttpResource(HttpServer):
class MockKey(object): class MockKey(object):
alg = "mock_alg" alg = "mock_alg"
version = "mock_version" version = "mock_version"
signature = b"\x9a\x87$"
@property @property
def verify_key(self): def verify_key(self):
return self return self
def sign(self, message): def sign(self, message):
return b"\x9a\x87$" return self
def verify(self, message, sig): def verify(self, message, sig):
assert sig == b"\x9a\x87$" assert sig == b"\x9a\x87$"