Begin implementing all the PDU storage stuff in Events land

This commit is contained in:
Erik Johnston 2014-10-28 16:42:35 +00:00
parent da1dda3e1d
commit 2d1dfb3b34
9 changed files with 485 additions and 42 deletions

View File

@ -71,7 +71,9 @@ class SynapseEvent(JsonEncodedObject):
"outlier", "outlier",
"power_level", "power_level",
"redacted", "redacted",
"prev_pdus", "prev_events",
"hashes",
"signatures",
] ]
required_keys = [ required_keys = [

View File

@ -47,7 +47,10 @@ class PduCodec(object):
kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin) kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin)
kwargs["room_id"] = pdu.context kwargs["room_id"] = pdu.context
kwargs["etype"] = pdu.pdu_type kwargs["etype"] = pdu.pdu_type
kwargs["prev_pdus"] = pdu.prev_pdus kwargs["prev_events"] = [
encode_event_id(i, o)
for i, o in pdu.prev_pdus
]
if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"): if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"):
kwargs["prev_state"] = encode_event_id( kwargs["prev_state"] = encode_event_id(
@ -78,8 +81,8 @@ class PduCodec(object):
d["context"] = event.room_id d["context"] = event.room_id
d["pdu_type"] = event.type d["pdu_type"] = event.type
if hasattr(event, "prev_pdus"): if hasattr(event, "prev_events"):
d["prev_pdus"] = event.prev_pdus d["prev_pdus"] = [decode_event_id(e) for e in event.prev_events]
if hasattr(event, "prev_state"): if hasattr(event, "prev_state"):
d["prev_state_id"], d["prev_state_origin"] = ( d["prev_state_id"], d["prev_state_origin"] = (
@ -92,7 +95,7 @@ class PduCodec(object):
kwargs = copy.deepcopy(event.unrecognized_keys) kwargs = copy.deepcopy(event.unrecognized_keys)
kwargs.update({ kwargs.update({
k: v for k, v in d.items() k: v for k, v in d.items()
if k not in ["event_id", "room_id", "type"] if k not in ["event_id", "room_id", "type", "prev_events"]
}) })
if "origin_server_ts" not in kwargs: if "origin_server_ts" not in kwargs:

View File

@ -40,6 +40,7 @@ from .stream import StreamStore
from .pdu import StatePduStore, PduStore, PdusTable from .pdu import StatePduStore, PduStore, PdusTable
from .transactions import TransactionStore from .transactions import TransactionStore
from .keys import KeyStore from .keys import KeyStore
from .event_federation import EventFederationStore
from .state import StateStore from .state import StateStore
from .signatures import SignatureStore from .signatures import SignatureStore
@ -69,6 +70,7 @@ SCHEMAS = [
"redactions", "redactions",
"state", "state",
"signatures", "signatures",
"event_edges",
] ]
@ -83,10 +85,12 @@ class _RollbackButIsFineException(Exception):
""" """
pass pass
class DataStore(RoomMemberStore, RoomStore, class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, FeedbackStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
PresenceStore, PduStore, StatePduStore, TransactionStore, PresenceStore, PduStore, StatePduStore, TransactionStore,
DirectoryStore, KeyStore, StateStore, SignatureStore): DirectoryStore, KeyStore, StateStore, SignatureStore,
EventFederationStore, ):
def __init__(self, hs): def __init__(self, hs):
super(DataStore, self).__init__(hs) super(DataStore, self).__init__(hs)
@ -230,6 +234,10 @@ class DataStore(RoomMemberStore, RoomStore,
elif event.type == RoomRedactionEvent.TYPE: elif event.type == RoomRedactionEvent.TYPE:
self._store_redaction(txn, event) self._store_redaction(txn, event)
outlier = False
if hasattr(event, "outlier"):
outlier = event.outlier
vals = { vals = {
"topological_ordering": event.depth, "topological_ordering": event.depth,
"event_id": event.event_id, "event_id": event.event_id,
@ -237,20 +245,20 @@ class DataStore(RoomMemberStore, RoomStore,
"room_id": event.room_id, "room_id": event.room_id,
"content": json.dumps(event.content), "content": json.dumps(event.content),
"processed": True, "processed": True,
"outlier": outlier,
"depth": event.depth,
} }
if stream_ordering is not None: if stream_ordering is not None:
vals["stream_ordering"] = stream_ordering vals["stream_ordering"] = stream_ordering
if hasattr(event, "outlier"):
vals["outlier"] = event.outlier
else:
vals["outlier"] = False
unrec = { unrec = {
k: v k: v
for k, v in event.get_full_dict().items() for k, v in event.get_full_dict().items()
if k not in vals.keys() and k not in ["redacted", "redacted_because"] if k not in vals.keys() and k not in [
"redacted", "redacted_because", "signatures", "hashes",
"prev_events",
]
} }
vals["unrecognized_keys"] = json.dumps(unrec) vals["unrecognized_keys"] = json.dumps(unrec)
@ -264,6 +272,14 @@ class DataStore(RoomMemberStore, RoomStore,
) )
raise _RollbackButIsFineException("_persist_event") raise _RollbackButIsFineException("_persist_event")
self._handle_prev_events(
txn,
outlier=outlier,
event_id=event.event_id,
prev_events=event.prev_events,
room_id=event.room_id,
)
self._store_state_groups_txn(txn, event) self._store_state_groups_txn(txn, event)
is_state = hasattr(event, "state_key") and event.state_key is not None is_state = hasattr(event, "state_key") and event.state_key is not None
@ -291,6 +307,28 @@ class DataStore(RoomMemberStore, RoomStore,
} }
) )
signatures = event.signatures.get(event.origin, {})
for key_id, signature_base64 in signatures.items():
signature_bytes = decode_base64(signature_base64)
self._store_event_origin_signature_txn(
txn, event.event_id, key_id, signature_bytes,
)
for prev_event_id, prev_hashes in event.prev_events:
for alg, hash_base64 in prev_hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_prev_event_hash_txn(
txn, event.event_id, prev_event_id, alg, hash_bytes
)
(ref_alg, ref_hash_bytes) = compute_pdu_event_reference_hash(pdu)
self._store_pdu_reference_hash_txn(
txn, pdu.pdu_id, pdu.origin, ref_alg, ref_hash_bytes
)
self._update_min_depth_for_room_txn(txn, event.room_id, event.depth)
def _store_redaction(self, txn, event): def _store_redaction(self, txn, event):
txn.execute( txn.execute(
"INSERT OR IGNORE INTO redactions " "INSERT OR IGNORE INTO redactions "
@ -373,7 +411,7 @@ class DataStore(RoomMemberStore, RoomStore,
""" """
def _snapshot(txn): def _snapshot(txn):
membership_state = self._get_room_member(txn, user_id, room_id) membership_state = self._get_room_member(txn, user_id, room_id)
prev_pdus = self._get_latest_pdus_in_context( prev_events = self._get_latest_events_in_room(
txn, room_id txn, room_id
) )
@ -388,7 +426,7 @@ class DataStore(RoomMemberStore, RoomStore,
store=self, store=self,
room_id=room_id, room_id=room_id,
user_id=user_id, user_id=user_id,
prev_pdus=prev_pdus, prev_events=prev_events,
membership_state=membership_state, membership_state=membership_state,
state_type=state_type, state_type=state_type,
state_key=state_key, state_key=state_key,
@ -404,7 +442,7 @@ class Snapshot(object):
store (DataStore): The datastore. store (DataStore): The datastore.
room_id (RoomId): The room of the snapshot. room_id (RoomId): The room of the snapshot.
user_id (UserId): The user this snapshot is for. user_id (UserId): The user this snapshot is for.
prev_pdus (list): The list of PDU ids this snapshot is after. prev_events (list): The list of event ids this snapshot is after.
membership_state (RoomMemberEvent): The current state of the user in membership_state (RoomMemberEvent): The current state of the user in
the room. the room.
state_type (str, optional): State type captured by the snapshot state_type (str, optional): State type captured by the snapshot
@ -413,29 +451,29 @@ class Snapshot(object):
the previous value of the state type and key in the room. the previous value of the state type and key in the room.
""" """
def __init__(self, store, room_id, user_id, prev_pdus, def __init__(self, store, room_id, user_id, prev_events,
membership_state, state_type=None, state_key=None, membership_state, state_type=None, state_key=None,
prev_state_pdu=None): prev_state_pdu=None):
self.store = store self.store = store
self.room_id = room_id self.room_id = room_id
self.user_id = user_id self.user_id = user_id
self.prev_pdus = prev_pdus self.prev_events = prev_events
self.membership_state = membership_state self.membership_state = membership_state
self.state_type = state_type self.state_type = state_type
self.state_key = state_key self.state_key = state_key
self.prev_state_pdu = prev_state_pdu self.prev_state_pdu = prev_state_pdu
def fill_out_prev_events(self, event): def fill_out_prev_events(self, event):
if hasattr(event, "prev_pdus"): if hasattr(event, "prev_events"):
return return
event.prev_pdus = [ event.prev_events = [
(p_id, origin, hashes) (p_id, origin, hashes)
for p_id, origin, hashes, _ in self.prev_pdus for p_id, origin, hashes, _ in self.prev_events
] ]
if self.prev_pdus: if self.prev_events:
event.depth = max([int(v) for _, _, _, v in self.prev_pdus]) + 1 event.depth = max([int(v) for _, _, _, v in self.prev_events]) + 1
else: else:
event.depth = 0 event.depth = 0

View File

@ -193,7 +193,6 @@ class SQLBaseStore(object):
table, keyvalues, retcols=retcols, allow_none=allow_none table, keyvalues, retcols=retcols, allow_none=allow_none
) )
@defer.inlineCallbacks
def _simple_select_one_onecol(self, table, keyvalues, retcol, def _simple_select_one_onecol(self, table, keyvalues, retcol,
allow_none=False): allow_none=False):
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
@ -204,19 +203,41 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the row with keyvalues : dict of column names and values to select the row with
retcol : string giving the name of the column to return retcol : string giving the name of the column to return
""" """
ret = yield self._simple_select_one( return self.runInteraction(
"_simple_select_one_onecol_txn",
self._simple_select_one_onecol_txn,
table, keyvalues, retcol, allow_none=allow_none,
)
def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
allow_none=False):
ret = self._simple_select_onecol_txn(
txn,
table=table, table=table,
keyvalues=keyvalues, keyvalues=keyvalues,
retcols=[retcol], retcols=retcol,
allow_none=allow_none
) )
if ret: if ret:
defer.returnValue(ret[retcol]) return ret[retcol]
else: else:
defer.returnValue(None) if allow_none:
return None
else:
raise StoreError(404, "No row found")
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
"retcol": retcol,
"table": table,
"where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
}
txn.execute(sql, keyvalues.values())
return [r[0] for r in txn.fetchall()]
@defer.inlineCallbacks
def _simple_select_onecol(self, table, keyvalues, retcol): def _simple_select_onecol(self, table, keyvalues, retcol):
"""Executes a SELECT query on the named table, which returns a list """Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows. comprising of the values of the named column from the selected rows.
@ -229,19 +250,11 @@ class SQLBaseStore(object):
Returns: Returns:
Deferred: Results in a list Deferred: Results in a list
""" """
sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % { return self.runInteraction(
"retcol": retcol, "_simple_select_onecol",
"table": table, self._simple_select_onecol_txn,
"where": " AND ".join("%s = ?" % k for k in keyvalues.keys()), table, keyvalues, retcol
} )
def func(txn):
txn.execute(sql, keyvalues.values())
return txn.fetchall()
res = yield self.runInteraction("_simple_select_onecol", func)
defer.returnValue([r[0] for r in res])
def _simple_select_list(self, table, keyvalues, retcols): def _simple_select_list(self, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or

View File

@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
class EventFederationStore(SQLBaseStore):
    """Stores the event "graph": prev-event edges, forward/backward
    extremities and the minimum known depth for each room.
    """

    def _get_latest_events_in_room(self, txn, room_id):
        """Return the current forward extremities of a room.

        Args:
            txn (cursor):
            room_id (str): The room to look up.

        Returns:
            list of (event_id, origin, prev_hashes, depth) tuples, one per
            forward extremity. `origin` is always None for events (it only
            existed for PDUs) and is kept solely so existing callers that
            unpack four elements keep working.
        """
        # Join against events to pick up each extremity's depth; the
        # extremities table itself only stores (event_id, room_id).
        txn.execute(
            "SELECT e.event_id, e.depth FROM events as e "
            "INNER JOIN event_forward_extremities as f "
            "ON e.event_id = f.event_id "
            "WHERE f.room_id = ?",
            (room_id, )
        )

        results = []
        for event_id, depth in txn.fetchall():
            hashes = self._get_event_reference_hashes_txn(txn, event_id)
            sha256_bytes = hashes["sha256"]
            # NOTE(review): encode_base64 is not imported by this module;
            # the import must be added before this path is exercised.
            prev_hashes = {"sha256": encode_base64(sha256_bytes)}
            results.append((event_id, None, prev_hashes, depth))
        return results

    def _get_min_depth_interaction(self, txn, room_id):
        """Return the stored minimum depth for a room, or None if unknown."""
        min_depth = self._simple_select_one_onecol_txn(
            txn,
            table="room_depth",
            keyvalues={"room_id": room_id},
            retcol="min_depth",
            allow_none=True,
        )

        return int(min_depth) if min_depth is not None else None

    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
        """Lower the recorded minimum depth of a room to `depth`, inserting
        the row if the room has no recorded depth yet.
        """
        min_depth = self._get_min_depth_interaction(txn, room_id)

        # `min_depth` may legitimately be 0 (events default to depth 0), so
        # test against None explicitly rather than relying on truthiness.
        if min_depth is None or depth < min_depth:
            self._simple_insert_txn(
                txn,
                table="room_depth",
                values={
                    "room_id": room_id,
                    "min_depth": depth,
                },
                or_replace=True,
            )

    def _handle_prev_events(self, txn, outlier, event_id, prev_events,
                            room_id):
        """Record the prev-event edges for a newly persisted event and keep
        the forward/backward extremity tables up to date.

        Args:
            txn (cursor):
            outlier (bool): True if the event was received out of band, in
                which case the extremity tables are left untouched.
            event_id (str): The new event.
            prev_events (list): Ids of the events it references.
            room_id (str): The room the event belongs to.
        """
        for e_id in prev_events:
            # TODO (erikj): This could be done as a bulk insert
            self._simple_insert_txn(
                txn,
                table="event_edges",
                values={
                    "event_id": event_id,
                    # Column is `prev_event_id` in the schema, not
                    # `prev_event`.
                    "prev_event_id": e_id,
                    "room_id": room_id,
                }
            )

        # Update the extremities table if this is not an outlier.
        if not outlier:
            # The referenced events are no longer "latest" in this room.
            for e_id in prev_events:
                # TODO (erikj): This could be done as a bulk insert
                self._simple_delete_txn(
                    txn,
                    table="event_forward_extremities",
                    keyvalues={
                        "event_id": e_id,
                        "room_id": room_id,
                    }
                )

            # We only insert as a forward extremity the new event if there
            # are no other events that reference it as a prev event.
            query = (
                "INSERT INTO %(table)s (event_id, room_id) "
                "SELECT ?, ? WHERE NOT EXISTS ("
                "SELECT 1 FROM %(event_edges)s WHERE "
                "prev_event_id = ? "
                ")"
            ) % {
                "table": "event_forward_extremities",
                "event_edges": "event_edges",
            }

            logger.debug("query: %s", query)

            txn.execute(query, (event_id, room_id, event_id))

            # Insert all the prev_events as backward extremities; incorrect
            # ones are deleted again just below.
            for e_id in prev_events:
                # TODO (erikj): This could be done as a bulk insert
                self._simple_insert_txn(
                    txn,
                    table="event_backward_extremities",
                    values={
                        "event_id": e_id,
                        "room_id": room_id,
                    }
                )

            # Also delete from the backwards extremities table all entries
            # that reference events we have already seen. SQLite does not
            # accept a table alias in DELETE FROM, so the correlated
            # subquery names the table in full instead.
            query = (
                "DELETE FROM event_backward_extremities WHERE EXISTS ("
                "SELECT 1 FROM events "
                "WHERE events.event_id = event_backward_extremities.event_id "
                "AND NOT events.outlier"
                ")"
            )
            txn.execute(query)

View File

@ -0,0 +1,51 @@
/* Copyright 2014 OpenMarket Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

-- The "latest" events in each room: events that no other event we hold
-- references as a prev event.
CREATE TABLE IF NOT EXISTS event_forward_extremities(
    event_id TEXT,
    room_id TEXT,
    CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
);

CREATE INDEX IF NOT EXISTS ev_extrem_room ON event_forward_extremities(room_id);
CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id);
--
-- Events that are referenced as prev events but that we do not yet hold.
CREATE TABLE IF NOT EXISTS event_backward_extremities(
    event_id TEXT,
    room_id TEXT,
    CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
);

CREATE INDEX IF NOT EXISTS ev_b_extrem_room ON event_backward_extremities(room_id);
CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id);
--
-- One row per (event, prev event) edge in the event graph.
-- NOTE(review): unlike the other tables here, the uniqueness constraint
-- has no ON CONFLICT clause, so duplicate inserts raise — confirm that is
-- intended.
CREATE TABLE IF NOT EXISTS event_edges(
    event_id TEXT,
    prev_event_id TEXT,
    room_id TEXT,
    CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id)
);

CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id);
CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id);
--
-- The minimum event depth we hold for each room.
CREATE TABLE IF NOT EXISTS room_depth(
    room_id TEXT,
    min_depth INTEGER,
    CONSTRAINT uniqueness UNIQUE (room_id)
);

CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id);
--
-- Per-destination delivery bookkeeping for outgoing events.
CREATE TABLE IF NOT EXISTS event_destinations(
    event_id TEXT,
    destination TEXT,
    delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
    CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE
);

CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id);
--
--

View File

@ -0,0 +1,65 @@
/* Copyright 2014 OpenMarket Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

-- Hashes of an event's content, one row per (event, algorithm).
CREATE TABLE IF NOT EXISTS event_content_hashes (
    event_id TEXT,
    algorithm TEXT,
    hash BLOB,
    CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
);

CREATE INDEX IF NOT EXISTS event_content_hashes_id ON event_content_hashes(
    event_id
);

-- "Reference" hashes of an event, one row per (event, algorithm).
CREATE TABLE IF NOT EXISTS event_reference_hashes (
    event_id TEXT,
    algorithm TEXT,
    hash BLOB,
    CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
);

CREATE INDEX IF NOT EXISTS event_reference_hashes_id ON event_reference_hashes (
    event_id
);

-- Signatures over an event by its origin server, one row per signing key.
-- NOTE(review): uniqueness is (event_id, key_id) only, not including
-- origin — confirm one origin per event is the intended model.
CREATE TABLE IF NOT EXISTS event_origin_signatures (
    event_id TEXT,
    origin TEXT,
    key_id TEXT,
    signature BLOB,
    CONSTRAINT uniqueness UNIQUE (event_id, key_id)
);

CREATE INDEX IF NOT EXISTS event_origin_signatures_id ON event_origin_signatures (
    event_id
);

-- Hashes recorded for each (event, prev event) edge of the graph.
CREATE TABLE IF NOT EXISTS event_edge_hashes(
    event_id TEXT,
    prev_event_id TEXT,
    algorithm TEXT,
    hash BLOB,
    CONSTRAINT uniqueness UNIQUE (
        event_id, prev_event_id, algorithm
    )
);

CREATE INDEX IF NOT EXISTS event_edge_hashes_id ON event_edge_hashes(
    event_id
);

View File

@ -23,6 +23,7 @@ CREATE TABLE IF NOT EXISTS events(
unrecognized_keys TEXT, unrecognized_keys TEXT,
processed BOOL NOT NULL, processed BOOL NOT NULL,
outlier BOOL NOT NULL, outlier BOOL NOT NULL,
depth INTEGER DEFAULT 0 NOT NULL,
CONSTRAINT ev_uniq UNIQUE (event_id) CONSTRAINT ev_uniq UNIQUE (event_id)
); );

View File

@ -153,3 +153,130 @@ class SignatureStore(SQLBaseStore):
"algorithm": algorithm, "algorithm": algorithm,
"hash": buffer(hash_bytes), "hash": buffer(hash_bytes),
}) })
## Events ##
def _get_event_content_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given Event.
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
A dict of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
" FROM event_content_hashes"
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
return dict(txn.fetchall())
def _store_event_content_hash_txn(self, txn, event_id, algorithm,
                                  hash_bytes):
    """Persist one content hash for an event.

    Args:
        txn (cursor):
        event_id (str): Id for the Event.
        algorithm (str): Hashing algorithm.
        hash_bytes (bytes): Hash function output bytes.
    """
    row = {
        "event_id": event_id,
        "algorithm": algorithm,
        "hash": buffer(hash_bytes),
    }
    self._simple_insert_txn(txn, "event_content_hashes", row)
def _get_event_reference_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given PDU.
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
A dict of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
" FROM event_reference_hashes"
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
return dict(txn.fetchall())
def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
                                    hash_bytes):
    """Persist one reference hash for an event.

    Args:
        txn (cursor):
        event_id (str): Id for the Event.
        algorithm (str): Hashing algorithm.
        hash_bytes (bytes): Hash function output bytes.
    """
    row = {
        "event_id": event_id,
        "algorithm": algorithm,
        "hash": buffer(hash_bytes),
    }
    self._simple_insert_txn(txn, "event_reference_hashes", row)
def _get_event_origin_signatures_txn(self, txn, event_id):
"""Get all the signatures for a given PDU.
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
A dict of key_id -> signature_bytes.
"""
query = (
"SELECT key_id, signature"
" FROM event_origin_signatures"
" WHERE event_id = ? "
)
txn.execute(query, (event_id, ))
return dict(txn.fetchall())
def _store_event_origin_signature_txn(self, txn, event_id, origin, key_id,
                                      signature_bytes):
    """Persist a signature from the origin server for an event.

    Args:
        txn (cursor):
        event_id (str): Id for the Event.
        origin (str): origin of the Event.
        key_id (str): Id for the signing key.
        signature_bytes (bytes): The signature.
    """
    row = {
        "event_id": event_id,
        "origin": origin,
        "key_id": key_id,
        "signature": buffer(signature_bytes),
    }
    self._simple_insert_txn(txn, "event_origin_signatures", row)
def _get_prev_event_hashes_txn(self, txn, event_id):
"""Get all the hashes for previous PDUs of a PDU
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes.
"""
query = (
"SELECT prev_event_id, algorithm, hash"
" FROM event_edge_hashes"
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
results = {}
for prev_event_id, algorithm, hash_bytes in txn.fetchall():
hashes = results.setdefault(prev_event_id, {})
hashes[algorithm] = hash_bytes
return results
def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id,
                               algorithm, hash_bytes):
    """Persist the hash recorded for one prev-event edge of an event.

    Args:
        txn (cursor):
        event_id (str): Id for the Event.
        prev_event_id (str): Id of the referenced prev Event.
        algorithm (str): Hashing algorithm.
        hash_bytes (bytes): Hash function output bytes.
    """
    row = {
        "event_id": event_id,
        "prev_event_id": prev_event_id,
        "algorithm": algorithm,
        "hash": buffer(hash_bytes),
    }
    self._simple_insert_txn(txn, "event_edge_hashes", row)