Implement new replace_state and changed prev_state

`prev_state` is now a list of previous state ids, similiar to
prev_events. `replace_state` now points to what we think was replaced.
This commit is contained in:
Erik Johnston 2014-11-06 15:10:55 +00:00
parent 3791b75000
commit 4317c8e583
13 changed files with 219 additions and 127 deletions

View File

@ -60,6 +60,7 @@ class SynapseEvent(JsonEncodedObject):
"age_ts", "age_ts",
"prev_content", "prev_content",
"prev_state", "prev_state",
"replaces_state",
"redacted_because", "redacted_because",
"origin_server_ts", "origin_server_ts",
] ]

View File

@ -147,10 +147,7 @@ class DirectoryHandler(BaseHandler):
content={"aliases": aliases}, content={"aliases": aliases},
) )
snapshot = yield self.store.snapshot_room( snapshot = yield self.store.snapshot_room(event)
room_id=room_id,
user_id=user_id,
)
yield self._on_new_room_event( yield self._on_new_room_event(
event, snapshot, extra_users=[user_id], suppress_auth=True event, snapshot, extra_users=[user_id], suppress_auth=True

View File

@ -313,9 +313,7 @@ class FederationHandler(BaseHandler):
state_key=user_id, state_key=user_id,
) )
snapshot = yield self.store.snapshot_room( snapshot = yield self.store.snapshot_room(event)
event.room_id, event.user_id,
)
snapshot.fill_out_prev_events(event) snapshot.fill_out_prev_events(event)
yield self.state_handler.annotate_state_groups(event) yield self.state_handler.annotate_state_groups(event)

View File

@ -81,7 +81,7 @@ class MessageHandler(BaseHandler):
user = self.hs.parse_userid(event.user_id) user = self.hs.parse_userid(event.user_id)
assert user.is_mine, "User must be our own: %s" % (user,) assert user.is_mine, "User must be our own: %s" % (user,)
snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) snapshot = yield self.store.snapshot_room(event)
yield self._on_new_room_event( yield self._on_new_room_event(
event, snapshot, suppress_auth=suppress_auth event, snapshot, suppress_auth=suppress_auth
@ -141,12 +141,7 @@ class MessageHandler(BaseHandler):
SynapseError if something went wrong. SynapseError if something went wrong.
""" """
snapshot = yield self.store.snapshot_room( snapshot = yield self.store.snapshot_room(event)
event.room_id,
event.user_id,
state_type=event.type,
state_key=event.state_key,
)
yield self._on_new_room_event(event, snapshot) yield self._on_new_room_event(event, snapshot)
@ -214,7 +209,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def send_feedback(self, event): def send_feedback(self, event):
snapshot = yield self.store.snapshot_room(event.room_id, event.user_id) snapshot = yield self.store.snapshot_room(event)
# store message in db # store message in db
yield self._on_new_room_event(event, snapshot) yield self._on_new_room_event(event, snapshot)

View File

@ -17,7 +17,6 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.api.events.room import RoomMemberEvent
from ._base import BaseHandler from ._base import BaseHandler
@ -196,10 +195,7 @@ class ProfileHandler(BaseHandler):
) )
for j in joins: for j in joins:
snapshot = yield self.store.snapshot_room( snapshot = yield self.store.snapshot_room(j)
j.room_id, j.state_key, RoomMemberEvent.TYPE,
j.state_key
)
content = { content = {
"membership": j.content["membership"], "membership": j.content["membership"],

View File

@ -122,10 +122,7 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_event(event): def handle_event(event):
snapshot = yield self.store.snapshot_room( snapshot = yield self.store.snapshot_room(event)
room_id=room_id,
user_id=user_id,
)
logger.debug("Event: %s", event) logger.debug("Event: %s", event)
@ -364,10 +361,8 @@ class RoomMemberHandler(BaseHandler):
""" """
target_user_id = event.state_key target_user_id = event.state_key
snapshot = yield self.store.snapshot_room( snapshot = yield self.store.snapshot_room(event)
event.room_id, event.user_id,
RoomMemberEvent.TYPE, target_user_id
)
## TODO(markjh): get prev state from snapshot. ## TODO(markjh): get prev state from snapshot.
prev_state = yield self.store.get_room_member( prev_state = yield self.store.get_room_member(
target_user_id, event.room_id target_user_id, event.room_id
@ -442,10 +437,7 @@ class RoomMemberHandler(BaseHandler):
content=content, content=content,
) )
snapshot = yield self.store.snapshot_room( snapshot = yield self.store.snapshot_room(new_event)
room_id, joinee.to_string(), RoomMemberEvent.TYPE,
joinee.to_string()
)
yield self._do_join(new_event, snapshot, room_host=host, do_auth=True) yield self._do_join(new_event, snapshot, room_host=host, do_auth=True)

View File

@ -138,7 +138,7 @@ class RoomStateEventRestServlet(RestServlet):
raise SynapseError( raise SynapseError(
404, "Event not found.", errcode=Codes.NOT_FOUND 404, "Event not found.", errcode=Codes.NOT_FOUND
) )
defer.returnValue((200, data[0].get_dict()["content"])) defer.returnValue((200, data.get_dict()["content"]))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key): def on_PUT(self, request, room_id, event_type, state_key):

View File

@ -45,40 +45,6 @@ class StateHandler(object):
self.server_name = hs.hostname self.server_name = hs.hostname
self.hs = hs self.hs = hs
@defer.inlineCallbacks
@log_function
def handle_new_event(self, event, snapshot):
""" Given an event this works out if a) we have sufficient power level
to update the state and b) works out what the prev_state should be.
Returns:
Deferred: Resolved with a boolean indicating if we successfully
updated the state.
Raised:
AuthError
"""
# This needs to be done in a transaction.
if not hasattr(event, "state_key"):
return
# Now I need to fill out the prev state and work out if it has auth
# (w.r.t. to power levels)
snapshot.fill_out_prev_events(event)
yield self.annotate_state_groups(event)
if event.old_state_events:
current_state = event.old_state_events.get(
(event.type, event.state_key)
)
if current_state:
event.prev_state = current_state.event_id
defer.returnValue(True)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def annotate_state_groups(self, event, old_state=None): def annotate_state_groups(self, event, old_state=None):
@ -111,7 +77,10 @@ class StateHandler(object):
event.old_state_events = copy.deepcopy(new_state) event.old_state_events = copy.deepcopy(new_state)
if hasattr(event, "state_key"): if hasattr(event, "state_key"):
new_state[(event.type, event.state_key)] = event key = (event.type, event.state_key)
if key in new_state:
event.replaces_state = new_state[key].event_id
new_state[key] = event
event.state_group = None event.state_group = None
event.state_events = new_state event.state_events = new_state

View File

@ -242,8 +242,8 @@ class DataStore(RoomMemberStore, RoomStore,
"state_key": event.state_key, "state_key": event.state_key,
} }
if hasattr(event, "prev_state"): if hasattr(event, "replaces_state"):
vals["prev_state"] = event.prev_state vals["prev_state"] = event.replaces_state
self._simple_insert_txn(txn, "state_events", vals) self._simple_insert_txn(txn, "state_events", vals)
@ -258,6 +258,40 @@ class DataStore(RoomMemberStore, RoomStore,
} }
) )
for e_id, h in event.prev_state:
self._simple_insert_txn(
txn,
table="event_edges",
values={
"event_id": event.event_id,
"prev_event_id": e_id,
"room_id": event.room_id,
"is_state": 1,
},
or_ignore=True,
)
if not backfilled:
self._simple_insert_txn(
txn,
table="state_forward_extremities",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
}
)
for prev_state_id, _ in event.prev_state:
self._simple_delete_txn(
txn,
table="state_forward_extremities",
keyvalues={
"event_id": prev_state_id,
}
)
for hash_alg, hash_base64 in event.hashes.items(): for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64) hash_bytes = decode_base64(hash_base64)
self._store_event_content_hash_txn( self._store_event_content_hash_txn(
@ -357,7 +391,7 @@ class DataStore(RoomMemberStore, RoomStore,
], ],
) )
def snapshot_room(self, room_id, user_id, state_type=None, state_key=None): def snapshot_room(self, event):
"""Snapshot the room for an update by a user """Snapshot the room for an update by a user
Args: Args:
room_id (synapse.types.RoomId): The room to snapshot. room_id (synapse.types.RoomId): The room to snapshot.
@ -368,16 +402,29 @@ class DataStore(RoomMemberStore, RoomStore,
synapse.storage.Snapshot: A snapshot of the state of the room. synapse.storage.Snapshot: A snapshot of the state of the room.
""" """
def _snapshot(txn): def _snapshot(txn):
membership_state = self._get_room_member(txn, user_id, room_id) prev_events = self._get_latest_events_in_room(
prev_events = self._get_latest_events_in_room(txn, room_id) txn,
event.room_id
)
prev_state = None
state_key = None
if hasattr(event, "state_key"):
state_key = event.state_key
prev_state = self._get_latest_state_in_room(
txn,
event.room_id,
type=event.type,
state_key=state_key,
)
return Snapshot( return Snapshot(
store=self, store=self,
room_id=room_id, room_id=event.room_id,
user_id=user_id, user_id=event.user_id,
prev_events=prev_events, prev_events=prev_events,
membership_state=membership_state, prev_state=prev_state,
state_type=state_type, state_type=event.type,
state_key=state_key, state_key=state_key,
) )
@ -400,30 +447,29 @@ class Snapshot(object):
""" """
def __init__(self, store, room_id, user_id, prev_events, def __init__(self, store, room_id, user_id, prev_events,
membership_state, state_type=None, state_key=None, prev_state, state_type=None, state_key=None):
prev_state_pdu=None):
self.store = store self.store = store
self.room_id = room_id self.room_id = room_id
self.user_id = user_id self.user_id = user_id
self.prev_events = prev_events self.prev_events = prev_events
self.membership_state = membership_state self.prev_state = prev_state
self.state_type = state_type self.state_type = state_type
self.state_key = state_key self.state_key = state_key
self.prev_state_pdu = prev_state_pdu
def fill_out_prev_events(self, event): def fill_out_prev_events(self, event):
if hasattr(event, "prev_events"): if not hasattr(event, "prev_events"):
return event.prev_events = [
(event_id, hashes)
for event_id, hashes, _ in self.prev_events
]
event.prev_events = [ if self.prev_events:
(event_id, hashes) event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
for event_id, hashes, _ in self.prev_events else:
] event.depth = 0
if self.prev_events: if not hasattr(event, "prev_state") and self.prev_state is not None:
event.depth = max([int(v) for _, _, v in self.prev_events]) + 1 event.prev_state = self.prev_state
else:
event.depth = 0
def schema_path(schema): def schema_path(schema):

View File

@ -245,7 +245,6 @@ class SQLBaseStore(object):
return [r[0] for r in txn.fetchall()] return [r[0] for r in txn.fetchall()]
def _simple_select_onecol(self, table, keyvalues, retcol): def _simple_select_onecol(self, table, keyvalues, retcol):
"""Executes a SELECT query on the named table, which returns a list """Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows. comprising of the values of the named column from the selected rows.
@ -273,17 +272,30 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the rows with keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return retcols : list of strings giving the names of the columns to return
""" """
return self.runInteraction(
"_simple_select_list",
self._simple_select_list_txn,
table, keyvalues, retcols
)
def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
txn : Transaction object
table : string giving the table name
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
"""
sql = "SELECT %s FROM %s WHERE %s" % ( sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols), ", ".join(retcols),
table, table,
" AND ".join("%s = ?" % (k) for k in keyvalues) " AND ".join("%s = ?" % (k, ) for k in keyvalues)
) )
def func(txn): txn.execute(sql, keyvalues.values())
txn.execute(sql, keyvalues.values()) return self.cursor_to_dict(txn)
return self.cursor_to_dict(txn)
return self.runInteraction("_simple_select_list", func)
def _simple_update_one(self, table, keyvalues, updatevalues, def _simple_update_one(self, table, keyvalues, updatevalues,
retcols=None): retcols=None):
@ -417,6 +429,10 @@ class SQLBaseStore(object):
d.pop("topological_ordering", None) d.pop("topological_ordering", None)
d.pop("processed", None) d.pop("processed", None)
d["origin_server_ts"] = d.pop("ts", 0) d["origin_server_ts"] = d.pop("ts", 0)
replaces_state = d.pop("prev_state", None)
if replaces_state:
d["replaces_state"] = replaces_state
d.update(json.loads(row_dict["unrecognized_keys"])) d.update(json.loads(row_dict["unrecognized_keys"]))
d["content"] = json.loads(d["content"]) d["content"] = json.loads(d["content"])
@ -450,16 +466,32 @@ class SQLBaseStore(object):
k: encode_base64(v) for k, v in signatures.items() k: encode_base64(v) for k, v in signatures.items()
} }
ev.prev_events = self._get_prev_events(txn, ev.event_id) prevs = self._get_prev_events_and_state(txn, ev.event_id)
if hasattr(ev, "prev_state"): ev.prev_events = [
# Load previous state_content. (e_id, h)
# TODO: Should we be pulling this out above? for e_id, h, is_state in prevs
cursor = txn.execute(select_event_sql, (ev.prev_state,)) if is_state == 0
prevs = self.cursor_to_dict(cursor) ]
if prevs:
prev = self._parse_event_from_row(prevs[0]) if hasattr(ev, "state_key"):
ev.prev_content = prev.content ev.prev_state = [
(e_id, h)
for e_id, h, is_state in prevs
if is_state == 1
]
if hasattr(ev, "replaces_state"):
# Load previous state_content.
# FIXME (erikj): Handle multiple prev_states.
cursor = txn.execute(
select_event_sql,
(ev.replaces_state,)
)
prevs = self.cursor_to_dict(cursor)
if prevs:
prev = self._parse_event_from_row(prevs[0])
ev.prev_content = prev.content
if not hasattr(ev, "redacted"): if not hasattr(ev, "redacted"):
logger.debug("Doesn't have redacted key: %s", ev) logger.debug("Doesn't have redacted key: %s", ev)

View File

@ -69,19 +69,21 @@ class EventFederationStore(SQLBaseStore):
return results return results
def _get_prev_events(self, txn, event_id): def _get_latest_state_in_room(self, txn, room_id, type, state_key):
prev_ids = self._simple_select_onecol_txn( event_ids = self._simple_select_onecol_txn(
txn, txn,
table="event_edges", table="state_forward_extremities",
keyvalues={ keyvalues={
"event_id": event_id, "room_id": room_id,
"type": type,
"state_key": state_key,
}, },
retcol="prev_event_id", retcol="event_id",
) )
results = [] results = []
for prev_event_id in prev_ids: for event_id in event_ids:
hashes = self._get_event_reference_hashes_txn(txn, prev_event_id) hashes = self._get_event_reference_hashes_txn(txn, event_id)
prev_hashes = { prev_hashes = {
k: encode_base64(v) for k, v in hashes.items() k: encode_base64(v) for k, v in hashes.items()
if k == "sha256" if k == "sha256"
@ -90,6 +92,53 @@ class EventFederationStore(SQLBaseStore):
return results return results
def _get_prev_events(self, txn, event_id):
results = self._get_prev_events_and_state(
txn,
event_id,
is_state=0,
)
return [(e_id, h, ) for e_id, h, _ in results]
def _get_prev_state(self, txn, event_id):
results = self._get_prev_events_and_state(
txn,
event_id,
is_state=1,
)
return [(e_id, h, ) for e_id, h, _ in results]
def _get_prev_events_and_state(self, txn, event_id, is_state=None):
keyvalues = {
"event_id": event_id,
}
if is_state is not None:
keyvalues["is_state"] = is_state
res = self._simple_select_list_txn(
txn,
table="event_edges",
keyvalues=keyvalues,
retcols=["prev_event_id", "is_state"],
)
results = []
for d in res:
hashes = self._get_event_reference_hashes_txn(
txn,
d["prev_event_id"]
)
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items()
if k == "sha256"
}
results.append((d["prev_event_id"], prev_hashes, d["is_state"]))
return results
def get_min_depth(self, room_id): def get_min_depth(self, room_id):
return self.runInteraction( return self.runInteraction(
"get_min_depth", "get_min_depth",
@ -135,6 +184,7 @@ class EventFederationStore(SQLBaseStore):
"event_id": event_id, "event_id": event_id,
"prev_event_id": e_id, "prev_event_id": e_id,
"room_id": room_id, "room_id": room_id,
"is_state": 0,
}, },
or_ignore=True, or_ignore=True,
) )

View File

@ -1,7 +1,7 @@
CREATE TABLE IF NOT EXISTS event_forward_extremities( CREATE TABLE IF NOT EXISTS event_forward_extremities(
event_id TEXT, event_id TEXT NOT NULL,
room_id TEXT, room_id TEXT NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
); );
@ -10,8 +10,8 @@ CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_backward_extremities( CREATE TABLE IF NOT EXISTS event_backward_extremities(
event_id TEXT, event_id TEXT NOT NULL,
room_id TEXT, room_id TEXT NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
); );
@ -20,10 +20,11 @@ CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id
CREATE TABLE IF NOT EXISTS event_edges( CREATE TABLE IF NOT EXISTS event_edges(
event_id TEXT, event_id TEXT NOT NULL,
prev_event_id TEXT, prev_event_id TEXT NOT NULL,
room_id TEXT, room_id TEXT NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id) is_state INTEGER NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id, is_state)
); );
CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id); CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id);
@ -31,8 +32,8 @@ CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id);
CREATE TABLE IF NOT EXISTS room_depth( CREATE TABLE IF NOT EXISTS room_depth(
room_id TEXT, room_id TEXT NOT NULL,
min_depth INTEGER, min_depth INTEGER NOT NULL,
CONSTRAINT uniqueness UNIQUE (room_id) CONSTRAINT uniqueness UNIQUE (room_id)
); );
@ -40,10 +41,25 @@ CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id);
create TABLE IF NOT EXISTS event_destinations( create TABLE IF NOT EXISTS event_destinations(
event_id TEXT, event_id TEXT NOT NULL,
destination TEXT, destination TEXT NOT NULL,
delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE
); );
CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id); CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id);
CREATE TABLE IF NOT EXISTS state_forward_extremities(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
);
CREATE INDEX IF NOT EXISTS st_extrem_keys ON state_forward_extremities(
room_id, type, state_key
);
CREATE INDEX IF NOT EXISTS st_extrem_id ON state_forward_extremities(event_id);

View File

@ -80,7 +80,7 @@ class JsonEncodedObject(object):
def get_full_dict(self): def get_full_dict(self):
d = { d = {
k: v for (k, v) in self.__dict__.items() k: _encode(v) for (k, v) in self.__dict__.items()
if k in self.valid_keys or k in self.internal_keys if k in self.valid_keys or k in self.internal_keys
} }
d.update(self.unrecognized_keys) d.update(self.unrecognized_keys)