Move to storing state_groups_state as deltas

This commit is contained in:
Erik Johnston 2016-09-01 14:31:26 +01:00
parent 0cfd6c3161
commit 9e25443db8
5 changed files with 172 additions and 62 deletions

View File

@ -15,9 +15,25 @@
class EventContext(object):
    """Container for the state information computed alongside an event.

    Every attribute is declared in ``__slots__``, so instances carry no
    per-instance ``__dict__`` and cannot grow undeclared attributes.

    Attributes:
        current_state_ids: mapping of state key to event ID for the state
            including this event, or None until computed.
        prev_state_ids: mapping of state key to event ID for the state
            before this event, or None until computed.
        state_group: ID of the state group for this event's state, or None.
        rejected: defaults to False.
            NOTE(review): exact meaning of a truthy value is set by callers
            not visible here.
        push_actions: list, empty by default.
            NOTE(review): element shape is defined by callers - confirm.
        prev_group: ID of an earlier state group that ``delta_ids`` is
            expressed against, or None when no such group is known.
        delta_ids: mapping of state key to event ID holding only the
            entries that differ from ``prev_group``, or None.
        prev_state_events: None by default.
            NOTE(review): populated by the state handler - confirm contents.
    """

    __slots__ = [
        "current_state_ids",
        "prev_state_ids",
        "state_group",
        "rejected",
        "push_actions",
        "prev_group",
        "delta_ids",
        "prev_state_events",
    ]

    def __init__(self):
        self.current_state_ids = None
        self.prev_state_ids = None
        self.state_group = None

        self.rejected = False
        # A fresh list per instance, so contexts never share push actions.
        self.push_actions = []

        # Delta bookkeeping: which group the delta is against, and the delta.
        self.prev_group = None
        self.delta_ids = None

        self.prev_state_events = None

View File

@ -54,12 +54,15 @@ def _gen_state_id():
class _StateCacheEntry(object): class _StateCacheEntry(object):
__slots__ = ["state", "state_group", "state_id"] __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
def __init__(self, state, state_group): def __init__(self, state, state_group, prev_group=None, delta_ids=None):
self.state = state self.state = state
self.state_group = state_group self.state_group = state_group
self.prev_group = prev_group
self.delta_ids = delta_ids
# The `state_id` is a unique ID we generate that can be used as ID for # The `state_id` is a unique ID we generate that can be used as ID for
# this collection of state. Usually this would be the same as the # this collection of state. Usually this would be the same as the
# state group, but on worker instances we can't generate a new state # state group, but on worker instances we can't generate a new state
@ -243,11 +246,20 @@ class StateHandler(object):
if key in context.prev_state_ids: if key in context.prev_state_ids:
replaces = context.prev_state_ids[key] replaces = context.prev_state_ids[key]
event.unsigned["replaces_state"] = replaces event.unsigned["replaces_state"] = replaces
context.current_state_ids = dict(context.prev_state_ids) context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id context.current_state_ids[key] = event.event_id
context.prev_group = entry.prev_group
context.delta_ids = entry.delta_ids
if context.delta_ids is not None:
context.delta_ids[key] = event.event_id
else: else:
context.current_state_ids = context.prev_state_ids context.current_state_ids = context.prev_state_ids
context.prev_group = entry.prev_group
context.delta_ids = entry.delta_ids
context.prev_state_events = [] context.prev_state_events = []
defer.returnValue(context) defer.returnValue(context)
@ -281,6 +293,8 @@ class StateHandler(object):
defer.returnValue(_StateCacheEntry( defer.returnValue(_StateCacheEntry(
state=state_list, state=state_list,
state_group=name, state_group=name,
prev_group=name,
delta_ids={},
)) ))
if self._state_cache is not None: if self._state_cache is not None:
@ -330,6 +344,7 @@ class StateHandler(object):
if new_state_event_ids == frozenset(e_id for e_id in events): if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg state_group = sg
break break
if state_group is None: if state_group is None:
# Worker instances don't have access to this method, but we want # Worker instances don't have access to this method, but we want
# to set the state_group on the main instance to increase cache # to set the state_group on the main instance to increase cache
@ -337,9 +352,24 @@ class StateHandler(object):
if hasattr(self.store, "get_next_state_group"): if hasattr(self.store, "get_next_state_group"):
state_group = self.store.get_next_state_group() state_group = self.store.get_next_state_group()
prev_group = None
delta_ids = None
for old_group, old_ids in state_groups_ids.items():
if not set(new_state.iterkeys()) - set(old_ids.iterkeys()):
n_delta_ids = {
k: v
for k, v in new_state.items()
if old_ids.get(k) != v
}
if not delta_ids or len(n_delta_ids) < len(delta_ids):
prev_group = old_group
delta_ids = n_delta_ids
cache = _StateCacheEntry( cache = _StateCacheEntry(
state=new_state, state=new_state,
state_group=state_group, state_group=state_group,
prev_group=prev_group,
delta_ids=delta_ids,
) )
if self._state_cache is not None: if self._state_cache is not None:

View File

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 34 SCHEMA_VERSION = 35
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View File

@ -0,0 +1,21 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Links a state group to the earlier state group it is stored relative to.
 * When a row exists for state_group, the rows written to state_groups_state
 * for that group hold only the entries that differ from prev_state_group.
 * The full state is rebuilt by following prev_state_group links.
 */
CREATE TABLE state_group_edges(
state_group BIGINT NOT NULL,
prev_state_group BIGINT NOT NULL
);
/* Edges are looked up by state_group when walking back through the chain */
CREATE INDEX state_group_edges_idx ON state_group_edges(state_group);

View File

@ -16,6 +16,7 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches import intern_string from synapse.util.caches import intern_string
from synapse.storage.engines import PostgresEngine
from twisted.internet import defer from twisted.internet import defer
@ -118,20 +119,45 @@ class StateStore(SQLBaseStore):
}, },
) )
self._simple_insert_many_txn( if context.prev_group:
txn, self._simple_insert_txn(
table="state_groups_state", txn,
values=[ table="state_group_edges",
{ values={
"state_group": context.state_group, "state_group": context.state_group,
"room_id": event.room_id, "prev_state_group": context.prev_group,
"type": key[0], },
"state_key": key[1], )
"event_id": state_id,
} self._simple_insert_many_txn(
for key, state_id in state_event_ids.items() txn,
], table="state_groups_state",
) values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.delta_ids.items()
],
)
else:
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in state_event_ids.items()
],
)
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
@ -214,26 +240,70 @@ class StateStore(SQLBaseStore):
else: else:
where_clause = "" where_clause = ""
sql = (
"SELECT state_group, event_id, type, state_key"
" FROM state_groups_state WHERE"
" state_group IN (%s) %s" % (
",".join("?" for _ in groups),
where_clause,
)
)
args = list(groups)
if types is not None:
args.extend([i for typ in types for i in typ])
txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
results = {group: {} for group in groups} results = {group: {} for group in groups}
for row in rows: if isinstance(self.database_engine, PostgresEngine):
key = (row["type"], row["state_key"]) sql = ("""
results[row["state_group"]][key] = row["event_id"] WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT type, state_key, event_id FROM state_groups_state
WHERE ROW(type, state_key, state_group) IN (
SELECT type, state_key, max(state_group) FROM state
INNER JOIN state_groups_state USING (state_group)
GROUP BY type, state_key
)
%s;
""") % (where_clause,)
for group in groups:
args = [group]
if types is not None:
args.extend([i for typ in types for i in typ])
txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
for row in rows:
key = (row["type"], row["state_key"])
results[group][key] = row["event_id"]
else:
for group in groups:
group_tree = [group]
next_group = group
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
group_tree.append(next_group)
sql = ("""
SELECT type, state_key, event_id FROM state_groups_state
INNER JOIN (
SELECT type, state_key, max(state_group) as state_group
FROM state_groups_state
WHERE state_group IN (%s) %s
GROUP BY type, state_key
) USING (type, state_key, state_group);
""") % (",".join("?" for _ in group_tree), where_clause,)
args = list(group_tree)
if types is not None:
args.extend([i for typ in types for i in typ])
txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
for row in rows:
key = (row["type"], row["state_key"])
results[group][key] = row["event_id"]
return results return results
results = {} results = {}
@ -504,32 +574,5 @@ class StateStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
def get_all_new_state_groups(self, last_id, current_id, limit):
    """Fetch state groups with ``last_id < id <= current_id`` and their state.

    Args:
        last_id: exclusive lower bound on the state group ID.
        current_id: inclusive upper bound on the state group ID.
        limit: maximum number of state groups to fetch.

    Returns:
        Whatever ``self.runInteraction`` returns for the transaction, whose
        own result is a ``(groups, state_group_state)`` tuple: the
        ``(id, room_id, event_id)`` rows and the
        ``(state_group, type, state_key, event_id)`` rows for the ID range
        spanned by those groups. Both are empty lists if no group matched.
    """
    groups_sql = (
        "SELECT id, room_id, event_id FROM state_groups"
        " WHERE ? < id AND id <= ? ORDER BY id LIMIT ?"
    )
    state_sql = (
        "SELECT state_group, type, state_key, event_id"
        " FROM state_groups_state"
        " WHERE ? <= state_group AND state_group <= ?"
    )

    def get_all_new_state_groups_txn(txn):
        txn.execute(groups_sql, (last_id, current_id, limit))
        group_rows = txn.fetchall()
        if not group_rows:
            return ([], [])

        # Rows are ordered by id, so the first and last rows bound the range.
        first_id = group_rows[0][0]
        final_id = group_rows[-1][0]

        txn.execute(state_sql, (first_id, final_id))
        return (group_rows, txn.fetchall())

    return self.runInteraction(
        "get_all_new_state_groups", get_all_new_state_groups_txn
    )
def get_next_state_group(self):
    """Return the next value from this store's state group ID generator."""
    return self._state_groups_id_gen.get_next()