Merge pull request #2378 from matrix-org/erikj/group_sync_support

Add groups to sync stream
This commit is contained in:
Erik Johnston 2017-07-21 11:05:39 +01:00 committed by GitHub
commit 96917d5552
12 changed files with 283 additions and 12 deletions

View File

@ -41,6 +41,7 @@ from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.groups import SlavedGroupServerStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
@ -75,6 +76,7 @@ class SynchrotronSlavedStore(
SlavedRegistrationStore,
SlavedFilteringStore,
SlavedPresenceStore,
SlavedGroupServerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
SlavedClientIpStore,
@ -409,6 +411,10 @@ class SyncReplicationHandler(ReplicationClientHandler):
)
elif stream_name == "presence":
yield self.presence_handler.process_replication_rows(token, rows)
elif stream_name == "receipts":
self.notifier.on_new_event(
"groups_key", token, users=[row.user_id for row in rows],
)
def start(config_options):

View File

@ -63,6 +63,7 @@ class GroupsLocalHandler(object):
self.is_mine_id = hs.is_mine_id
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
self.notifier = hs.get_notifier()
self.attestations = hs.get_groups_attestation_signing()
# Ensure attestations get renewed
@ -212,13 +213,16 @@ class GroupsLocalHandler(object):
user_id=user_id,
)
yield self.store.register_user_group_membership(
token = yield self.store.register_user_group_membership(
group_id, user_id,
membership="join",
is_admin=False,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
defer.returnValue({})
@ -258,11 +262,14 @@ class GroupsLocalHandler(object):
if "avatar_url" in content["profile"]:
local_profile["avatar_url"] = content["profile"]["avatar_url"]
yield self.store.register_user_group_membership(
token = yield self.store.register_user_group_membership(
group_id, user_id,
membership="invite",
content={"profile": local_profile, "inviter": content["inviter"]},
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
defer.returnValue({"state": "invite"})
@ -271,10 +278,13 @@ class GroupsLocalHandler(object):
"""Remove a user from a group
"""
if user_id == requester_user_id:
yield self.store.register_user_group_membership(
token = yield self.store.register_user_group_membership(
group_id, user_id,
membership="leave",
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
# TODO: Should probably remember that we tried to leave so that we can
# retry if the group server is currently down.
@ -297,10 +307,13 @@ class GroupsLocalHandler(object):
"""One of our users was removed/kicked from a group
"""
# TODO: Check if user in group
yield self.store.register_user_group_membership(
token = yield self.store.register_user_group_membership(
group_id, user_id,
membership="leave",
)
self.notifier.on_new_event(
"groups_key", token, users=[user_id],
)
@defer.inlineCallbacks
def get_joined_groups(self, user_id):

View File

@ -108,6 +108,17 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
return True
class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
"join",
"invite",
"leave",
])):
__slots__ = []
def __nonzero__(self):
return bool(self.join or self.invite or self.leave)
class SyncResult(collections.namedtuple("SyncResult", [
"next_batch", # Token for the next sync
"presence", # List of presence events for the user.
@ -119,6 +130,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
"device_lists", # List of user_ids whose devices have chanegd
"device_one_time_keys_count", # Dict of algorithm to count for one time keys
# for this device
"groups",
])):
__slots__ = []
@ -134,7 +146,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
self.archived or
self.account_data or
self.to_device or
self.device_lists
self.device_lists or
self.groups
)
@ -560,6 +573,8 @@ class SyncHandler(object):
user_id, device_id
)
yield self._generate_sync_entry_for_groups(sync_result_builder)
defer.returnValue(SyncResult(
presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data,
@ -568,10 +583,56 @@ class SyncHandler(object):
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
groups=sync_result_builder.groups,
device_one_time_keys_count=one_time_key_counts,
next_batch=sync_result_builder.now_token,
))
@measure_func("_generate_sync_entry_for_groups")
@defer.inlineCallbacks
def _generate_sync_entry_for_groups(self, sync_result_builder):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token
if since_token and since_token.groups_key:
results = yield self.store.get_groups_changes_for_user(
user_id, since_token.groups_key, now_token.groups_key,
)
else:
results = yield self.store.get_all_groups_for_user(
user_id, now_token.groups_key,
)
invited = {}
joined = {}
left = {}
for result in results:
membership = result["membership"]
group_id = result["group_id"]
gtype = result["type"]
content = result["content"]
if membership == "join":
if gtype == "membership":
content.pop("membership", None)
invited[group_id] = content["content"]
else:
joined.setdefault(group_id, {})[gtype] = content
elif membership == "invite":
if gtype == "membership":
content.pop("membership", None)
invited[group_id] = content["content"]
else:
if gtype == "membership":
left[group_id] = content["content"]
sync_result_builder.groups = GroupsSyncResult(
join=joined,
invite=invited,
leave=left,
)
@measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks
def _generate_sync_entry_for_device_list(self, sync_result_builder):
@ -1260,6 +1321,7 @@ class SyncResultBuilder(object):
self.invited = []
self.archived = []
self.device = []
self.groups = None
class RoomSyncResultBuilder(object):

View File

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedGroupServerStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedGroupServerStore, self).__init__(db_conn, hs)
self.hs = hs
self._group_updates_id_gen = SlavedIdTracker(
db_conn, "local_group_updates", "stream_id",
)
self._group_updates_stream_cache = StreamChangeCache(
"_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
)
get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__
get_group_stream_token = DataStore.get_group_stream_token.__func__
get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__
def stream_positions(self):
result = super(SlavedGroupServerStore, self).stream_positions()
result["groups"] = self._group_updates_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "groups":
self._group_updates_id_gen.advance(token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(
row.user_id, token
)
return super(SlavedGroupServerStore, self).process_replication_rows(
stream_name, token, rows
)

View File

@ -118,6 +118,12 @@ CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
"state_key", # str
"event_id", # str, optional
))
GroupsStreamRow = namedtuple("GroupsStreamRow", (
"group_id", # str
"user_id", # str
"type", # str
"content", # dict
))
class Stream(object):
@ -464,6 +470,19 @@ class CurrentStateDeltaStream(Stream):
super(CurrentStateDeltaStream, self).__init__(hs)
class GroupServerStream(Stream):
NAME = "groups"
ROW_TYPE = GroupsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_group_stream_token
self.update_function = store.get_all_groups_changes
super(GroupServerStream, self).__init__(hs)
STREAMS_MAP = {
stream.NAME: stream
for stream in (
@ -482,5 +501,6 @@ STREAMS_MAP = {
TagAccountDataStream,
AccountDataStream,
CurrentStateDeltaStream,
GroupServerStream,
)
}

View File

@ -199,6 +199,11 @@ class SyncRestServlet(RestServlet):
"invite": invited,
"leave": archived,
},
"groups": {
"join": sync_result.groups.join,
"invite": sync_result.groups.invite,
"leave": sync_result.groups.leave,
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
"next_batch": sync_result.next_batch.to_string(),
}

View File

@ -136,6 +136,9 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "pushers", "id",
extra_tables=[("deleted_pushers", "stream_id")],
)
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id",
)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator(
@ -236,6 +239,18 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=curr_state_delta_prefill,
)
_group_updates_prefill, min_group_updates_id = self._get_cache_dict(
db_conn, "local_group_updates",
entity_column="user_id",
stream_column="stream_id",
max_value=self._group_updates_id_gen.get_current_token(),
limit=1000,
)
self._group_updates_stream_cache = StreamChangeCache(
"_group_updates_stream_cache", min_group_updates_id,
prefilled_cache=_group_updates_prefill,
)
cur = LoggingTransaction(
db_conn.cursor(),
name="_find_stream_orderings_for_times_txn",

View File

@ -776,7 +776,7 @@ class GroupServerStore(SQLBaseStore):
remote_attestation (dict): If remote group then store the remote
attestation from the group, else None.
"""
def _register_user_group_membership_txn(txn):
def _register_user_group_membership_txn(txn, next_id):
# TODO: Upsert?
self._simple_delete_txn(
txn,
@ -798,6 +798,19 @@ class GroupServerStore(SQLBaseStore):
},
)
self._simple_insert_txn(
txn,
table="local_group_updates",
values={
"stream_id": next_id,
"group_id": group_id,
"user_id": user_id,
"type": "membership",
"content": json.dumps({"membership": membership, "content": content}),
}
)
self._group_updates_stream_cache.entity_has_changed(user_id, next_id)
# TODO: Insert profile to ensure it comes down stream if its a join.
if membership == "join":
@ -840,9 +853,12 @@ class GroupServerStore(SQLBaseStore):
},
)
return next_id
with self._group_updates_id_gen.get_next() as next_id:
yield self.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
_register_user_group_membership_txn, next_id,
)
@defer.inlineCallbacks
@ -948,3 +964,68 @@ class GroupServerStore(SQLBaseStore):
retcol="group_id",
desc="get_joined_groups",
)
def get_all_groups_for_user(self, user_id, now_token):
def _get_all_groups_for_user_txn(txn):
sql = """
SELECT group_id, type, membership, u.content
FROM local_group_updates AS u
INNER JOIN local_group_membership USING (group_id, user_id)
WHERE user_id = ? AND membership != 'leave'
AND stream_id <= ?
"""
txn.execute(sql, (user_id, now_token,))
return self.cursor_to_dict(txn)
return self.runInteraction(
"get_all_groups_for_user", _get_all_groups_for_user_txn,
)
def get_groups_changes_for_user(self, user_id, from_token, to_token):
from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_entity_changed(
user_id, from_token,
)
if not has_changed:
return []
def _get_groups_changes_for_user_txn(txn):
sql = """
SELECT group_id, membership, type, u.content
FROM local_group_updates AS u
INNER JOIN local_group_membership USING (group_id, user_id)
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
"""
txn.execute(sql, (user_id, from_token, to_token,))
return [{
"group_id": group_id,
"membership": membership,
"type": gtype,
"content": json.loads(content_json),
} for group_id, membership, gtype, content_json in txn]
return self.runInteraction(
"get_groups_changes_for_user", _get_groups_changes_for_user_txn,
)
def get_all_groups_changes(self, from_token, to_token, limit):
from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_any_entity_changed(
from_token,
)
if not has_changed:
return []
def _get_all_groups_changes_txn(txn):
sql = """
SELECT stream_id, group_id, user_id, type, content
FROM local_group_updates
WHERE ? < stream_id AND stream_id <= ?
LIMIT ?
"""
txn.execute(sql, (from_token, to_token, limit,))
return txn.fetchall()
return self.runInteraction(
"get_all_groups_changes", _get_all_groups_changes_txn,
)
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()

View File

@ -155,3 +155,12 @@ CREATE TABLE local_group_membership (
CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id);
CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id);
CREATE TABLE local_group_updates (
stream_id BIGINT NOT NULL,
group_id TEXT NOT NULL,
user_id TEXT NOT NULL,
type TEXT NOT NULL,
content TEXT NOT NULL
);

View File

@ -45,6 +45,7 @@ class EventSources(object):
push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
groups_key = self.store.get_group_stream_token()
token = StreamToken(
room_key=(
@ -65,6 +66,7 @@ class EventSources(object):
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
groups_key=groups_key,
)
defer.returnValue(token)
@ -73,6 +75,7 @@ class EventSources(object):
push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
groups_key = self.store.get_group_stream_token()
token = StreamToken(
room_key=(
@ -93,5 +96,6 @@ class EventSources(object):
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
groups_key=groups_key,
)
defer.returnValue(token)

View File

@ -171,6 +171,7 @@ class StreamToken(
"push_rules_key",
"to_device_key",
"device_list_key",
"groups_key",
))
):
_SEPARATOR = "_"
@ -209,6 +210,7 @@ class StreamToken(
or (int(other.push_rules_key) < int(self.push_rules_key))
or (int(other.to_device_key) < int(self.to_device_key))
or (int(other.device_list_key) < int(self.device_list_key))
or (int(other.groups_key) < int(self.groups_key))
)
def copy_and_advance(self, key, new_value):

View File

@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks
def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0_0"
token = "t1-0_0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token))
@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks
def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0_0"
token = "s0_0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token))