Fix limit logic for EventsStream (#7358)

* Factor out functions for injecting events into database

I want to add some more flexibility to the tools for injecting events into the
database, and I don't want to clutter up HomeserverTestCase with them, so let's
factor them out to a new file.

* Rework TestReplicationDataHandler

This wasn't very easy to work with: the mock wrapping was largely superfluous,
and it is useful to be able to inspect the received rows and to clear out the
received list.

* Fix AssertionErrors being thrown by EventsStream

Part of the problem was that there was an off-by-one error in the assertion,
but also the limit logic was too simple. Fix it all up and add some tests.
Richard van der Hoff, 2020-04-29 12:30:36 +01:00 (committed by GitHub)
parent eeef9633af
commit c2e1a2110f
14 changed files with 658 additions and 67 deletions
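
To see the off-by-one concretely before diving into the diffs (a sketch with hypothetical names, not Synapse's actual code): a query issued with LIMIT n can legitimately return exactly n rows, so asserting a strict inequality on the batch size throws as soon as the stream falls n or more rows behind.

    from typing import List, Tuple


    def fetch_batch(all_rows: List[Tuple[int, str]], limit: int) -> List[Tuple[int, str]]:
        """Stand-in for a query with an ORDER BY stream_id ASC LIMIT clause."""
        return all_rows[:limit]


    def get_updates(all_rows: List[Tuple[int, str]], limit: int):
        batch = fetch_batch(all_rows, limit)
        # The buggy version asserted len(batch) < limit, which is violated whenever
        # the backlog holds at least `limit` rows; only <= is actually guaranteed.
        assert len(batch) <= limit
        limited = len(batch) == limit  # more rows may remain to be fetched
        return batch, limited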

changelog.d/7358.bugfix (new file)

@@ -0,0 +1 @@
Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.

synapse/replication/tcp/handler.py

@@ -87,7 +87,9 @@ class ReplicationCommandHandler:
             stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
         }  # type: Dict[str, Stream]

-        self._position_linearizer = Linearizer("replication_position")
+        self._position_linearizer = Linearizer(
+            "replication_position", clock=self._clock
+        )

         # Map of stream to batched updates. See RdataCommand for info on how
         # batching works.

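The functional change here is small: the Linearizer now receives the homeserver's clock explicitly rather than constructing a default one, presumably so that tests driving a fake reactor stay in control of any timing inside it. A minimal sketch of that dependency-injection pattern (hypothetical names, not Synapse's Linearizer):

    class ManualClock:
        """A clock the test advances by hand, standing in for the reactor clock."""

        def __init__(self) -> None:
            self.now = 0.0

        def time(self) -> float:
            return self.now

        def advance(self, seconds: float) -> None:
            self.now += seconds


    class RateLimiter:
        """Allows one call per `interval` seconds, as judged by the injected clock."""

        def __init__(self, interval: float, clock: ManualClock) -> None:
            self._interval = interval
            self._clock = clock
            self._last = float("-inf")

        def allow(self) -> bool:
            now = self._clock.time()
            if now - self._last >= self._interval:
                self._last = now
                return True
            return False


    clock = ManualClock()
    limiter = RateLimiter(10.0, clock)
    assert limiter.allow()
    assert not limiter.allow()
    clock.advance(10.0)  # the test, not wall-clock time, moves time forward
    assert limiter.allow()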
synapse/replication/tcp/streams/events.py

@@ -170,22 +170,16 @@ class EventsStream(Stream):
         limited = False
         upper_limit = current_token

-        # next up is the state delta table
-        state_rows = await self._store.get_all_updated_current_state_deltas(
+        # next up is the state delta table.
+        (
+            state_rows,
+            upper_limit,
+            state_rows_limited,
+        ) = await self._store.get_all_updated_current_state_deltas(
             from_token, upper_limit, target_row_count
-        )  # type: List[Tuple]
-
-        # again, if we've hit the limit there, we'll need to limit the other sources
-        assert len(state_rows) < target_row_count
-        if len(state_rows) == target_row_count:
-            assert state_rows[-1][0] <= upper_limit
-            upper_limit = state_rows[-1][0]
-            limited = True
-
-        # FIXME: is it a given that there is only one row per stream_id in the
-        # state_deltas table (so that we can be sure that we have got all of the
-        # rows for upper_limit)?
+        )
+
+        limited = limited or state_rows_limited

         # finally, fetch the ex-outliers rows. We assume there are few enough of these
         # not to bother with the limit.

synapse/server.pyi

@@ -25,6 +25,7 @@ import synapse.server_notices.server_notices_manager
 import synapse.server_notices.server_notices_sender
 import synapse.state
 import synapse.storage
+from synapse.events.builder import EventBuilderFactory

 class HomeServer(object):
     @property
@@ -121,3 +122,7 @@ class HomeServer(object):
         pass
     def get_instance_id(self) -> str:
         pass
+    def get_event_builder_factory(self) -> EventBuilderFactory:
+        pass
+    def get_storage(self) -> synapse.storage.Storage:
+        pass

synapse/storage/data_stores/main/events_worker.py

@@ -19,7 +19,7 @@ import itertools
 import logging
 import threading
 from collections import namedtuple
-from typing import List, Optional
+from typing import List, Optional, Tuple

 from canonicaljson import json
 from constantly import NamedConstant, Names
@@ -1084,7 +1084,28 @@ class EventsWorkerStore(SQLBaseStore):
             "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
         )

-    def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
+    async def get_all_updated_current_state_deltas(
+        self, from_token: int, to_token: int, target_row_count: int
+    ) -> Tuple[List[Tuple], int, bool]:
+        """Fetch updates from current_state_delta_stream
+
+        Args:
+            from_token: The previous stream token. Updates from this stream id will
+                be excluded.
+
+            to_token: The current stream token (ie the upper limit). Updates up to this
+                stream id will be included (modulo the 'limit' param)
+
+            target_row_count: The number of rows to try to return. If more rows are
+                available, we will set 'limited' in the result. In the event of a large
+                batch, we may return more rows than this.
+
+        Returns:
+            A triplet `(updates, new_last_token, limited)`, where:
+               * `updates` is a list of database tuples.
+               * `new_last_token` is the new position in stream.
+               * `limited` is whether there are more updates to fetch.
+        """
+
         def get_all_updated_current_state_deltas_txn(txn):
             sql = """
                 SELECT stream_id, room_id, type, state_key, event_id
@@ -1092,10 +1113,45 @@ class EventsWorkerStore(SQLBaseStore):
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC LIMIT ?
             """
-            txn.execute(sql, (from_token, to_token, limit))
+            txn.execute(sql, (from_token, to_token, target_row_count))
             return txn.fetchall()

-        return self.db.runInteraction(
+        def get_deltas_for_stream_id_txn(txn, stream_id):
+            sql = """
+                SELECT stream_id, room_id, type, state_key, event_id
+                FROM current_state_delta_stream
+                WHERE stream_id = ?
+            """
+            txn.execute(sql, [stream_id])
+            return txn.fetchall()
+
+        # we need to make sure that, for every stream id in the results, we get *all*
+        # the rows with that stream id.
+
+        rows = await self.db.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
-        )
+        )  # type: List[Tuple]
+
+        # if we've got fewer rows than the limit, we're good
+        if len(rows) < target_row_count:
+            return rows, to_token, False
+
+        # we hit the limit, so reduce the upper limit so that we exclude the stream id
+        # of the last row in the result.
+        assert rows[-1][0] <= to_token
+        to_token = rows[-1][0] - 1
+
+        # search backwards through the list for the point to truncate
+        for idx in range(len(rows) - 1, 0, -1):
+            if rows[idx - 1][0] <= to_token:
+                return rows[:idx], to_token, True
+
+        # bother. We didn't get a full set of changes for even a single
+        # stream id. let's run the query again, without a row limit, but for
+        # just one stream id.
+        to_token += 1
+        rows = await self.db.runInteraction(
+            "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
+        )

+        return rows, to_token, True
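
The heart of the fix is the truncation step above: a batch capped with LIMIT may end part-way through a group of rows sharing one stream id, so the code walks backwards to the last complete stream id and lowers the returned token to match. A self-contained sketch of that technique on plain tuples (illustrative only, not Synapse code):

    from typing import List, Tuple


    def truncate_to_complete_stream_ids(
        rows: List[Tuple[int, str]], target_row_count: int, to_token: int
    ) -> Tuple[List[Tuple[int, str]], int, bool]:
        """Drop any trailing rows whose stream id group may be incomplete."""
        if len(rows) < target_row_count:
            # under budget: we know we got everything up to to_token
            return rows, to_token, False

        # the last stream id in the batch may have been cut short by the LIMIT,
        # so exclude it and walk back to the boundary of the previous id
        last_id = rows[-1][0]
        for idx in range(len(rows) - 1, 0, -1):
            if rows[idx - 1][0] < last_id:
                return rows[:idx], last_id - 1, True

        # every row shared a single stream id: a real caller must re-query that
        # one id without a row limit (this is what get_deltas_for_stream_id does)
        raise NotImplementedError("a single stream id exceeded the row budget")


    # two rows share stream id 7; with a budget of 4 the batch is cut back to id 6
    rows = [(5, "a"), (6, "b"), (7, "c"), (7, "d")]
    batch, new_token, limited = truncate_to_complete_stream_ids(rows, 4, 10)
    assert batch == [(5, "a"), (6, "b")] and new_token == 6 and limited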

tests/replication/tcp/streams/_base.py

@@ -12,10 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
-from typing import Optional
-
-from mock import Mock
+import logging
+from typing import Any, Dict, List, Optional, Tuple

 import attr
@@ -25,6 +24,7 @@ from twisted.web.http import HTTPChannel

 from synapse.app.generic_worker import GenericWorkerServer
 from synapse.http.site import SynapseRequest
+from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.tcp.client import ReplicationDataHandler
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
@@ -65,9 +65,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # databases objects are the same.
         self.worker_hs.get_datastore().db = hs.get_datastore().db

-        self.test_handler = Mock(
-            wraps=TestReplicationDataHandler(self.worker_hs.get_datastore())
-        )
+        self.test_handler = self._build_replication_data_handler()
         self.worker_hs.replication_data_handler = self.test_handler

         repl_handler = ReplicationCommandHandler(self.worker_hs)
@@ -78,6 +76,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self._client_transport = None
         self._server_transport = None

+    def _build_replication_data_handler(self):
+        return TestReplicationDataHandler(self.worker_hs.get_datastore())
+
     def reconnect(self):
         if self._client_transport:
             self.client.close()
@@ -174,22 +175,28 @@
 class TestReplicationDataHandler(ReplicationDataHandler):
     """Drop-in for ReplicationDataHandler which just collects RDATA rows"""

-    def __init__(self, hs):
-        super().__init__(hs)
-        self.streams = set()
-        self._received_rdata_rows = []
+    def __init__(self, store: BaseSlavedStore):
+        super().__init__(store)
+
+        # streams to subscribe to: map from stream id to position
+        self.stream_positions = {}  # type: Dict[str, int]
+
+        # list of received (stream_name, token, row) tuples
+        self.received_rdata_rows = []  # type: List[Tuple[str, int, Any]]

     def get_streams_to_replicate(self):
-        positions = {s: 0 for s in self.streams}
-        for stream, token, _ in self._received_rdata_rows:
-            if stream in self.streams:
-                positions[stream] = max(token, positions.get(stream, 0))
-        return positions
+        return self.stream_positions

     async def on_rdata(self, stream_name, token, rows):
         await super().on_rdata(stream_name, token, rows)
         for r in rows:
-            self._received_rdata_rows.append((stream_name, token, r))
+            self.received_rdata_rows.append((stream_name, token, r))
+
+        if (
+            stream_name in self.stream_positions
+            and token > self.stream_positions[stream_name]
+        ):
+            self.stream_positions[stream_name] = token

@@ -221,7 +228,7 @@ class _PushHTTPChannel(HTTPChannel):
         super().__init__()
         self.reactor = reactor

-        self._pull_to_push_producer = None
+        self._pull_to_push_producer = None  # type: Optional[_PullToPushProducer]

     def registerProducer(self, producer, streaming):
         # Convert pull producers to push producer.
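
Factoring out `_build_replication_data_handler` lets subclasses decorate the real handler rather than replace it: the receipts and typing test cases below wrap it in `Mock(wraps=...)`, keeping the real row-collecting behaviour while gaining call assertions. A small self-contained illustration of the wrapping pattern (toy class, not Synapse code):

    from mock import Mock


    class Collector:
        """Toy stand-in for TestReplicationDataHandler."""

        def __init__(self):
            self.rows = []

        def on_rdata(self, row):
            self.rows.append(row)


    real = Collector()
    handler = Mock(wraps=real)

    handler.on_rdata("some-row")

    # the call was passed through to the real object...
    assert real.rows == ["some-row"]
    # ...and recorded by the mock, so tests can make call assertions too
    handler.on_rdata.assert_called_once_with("some-row")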

tests/replication/tcp/streams/test_events.py (new file)

@@ -0,0 +1,417 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import (
    EventsStreamCurrentStateRow,
    EventsStreamEventRow,
    EventsStreamRow,
)
from synapse.rest import admin
from synapse.rest.client.v1 import login, room

from tests.replication.tcp.streams._base import BaseStreamTestCase
from tests.test_utils.event_injection import inject_event, inject_member_event


class EventsStreamTestCase(BaseStreamTestCase):
    servlets = [
        admin.register_servlets,
        login.register_servlets,
        room.register_servlets,
    ]

    def prepare(self, reactor, clock, hs):
        super().prepare(reactor, clock, hs)
        self.user_id = self.register_user("u1", "pass")
        self.user_tok = self.login("u1", "pass")

        self.reconnect()
        self.test_handler.stream_positions["events"] = 0

        self.room_id = self.helper.create_room_as(tok=self.user_tok)
        self.test_handler.received_rdata_rows.clear()

    def test_update_function_event_row_limit(self):
        """Test replication with many non-state events

        Checks that all events are correctly replicated when there are lots of
        event rows to be replicated.
        """
        # disconnect, so that we can stack up some changes
        self.disconnect()

        # generate lots of non-state events. We inject them using inject_event
        # so that they are not sent out over replication until we call
        # self.replicate().
        events = [
            self._inject_test_event()
            for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1)
        ]

        # also one state event
        state_event = self._inject_state_event()

        # tell the notifier to catch up to avoid duplicate rows.
        # workaround for https://github.com/matrix-org/synapse/issues/7360
        # FIXME remove this when the above is fixed
        self.replicate()

        # check we're testing what we think we are: no rows should yet have been
        # received
        self.assertEqual([], self.test_handler.received_rdata_rows)

        # now reconnect to pull the updates
        self.reconnect()
        self.replicate()

        # we should have received all the expected rows in the right order
        received_rows = self.test_handler.received_rdata_rows
        for event in events:
            stream_name, token, row = received_rows.pop(0)
            self.assertEqual("events", stream_name)
            self.assertIsInstance(row, EventsStreamRow)
            self.assertEqual(row.type, "ev")
            self.assertIsInstance(row.data, EventsStreamEventRow)
            self.assertEqual(row.data.event_id, event.event_id)

        stream_name, token, row = received_rows.pop(0)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertIsInstance(row.data, EventsStreamEventRow)
        self.assertEqual(row.data.event_id, state_event.event_id)

        stream_name, token, row = received_rows.pop(0)
        self.assertEqual("events", stream_name)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertEqual(row.type, "state")
        self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
        self.assertEqual(row.data.event_id, state_event.event_id)

        self.assertEqual([], received_rows)

    def test_update_function_huge_state_change(self):
        """Test replication with many state events

        Ensures that all events are correctly replicated when there are lots of
        state change rows to be replicated.
        """
        # we want to generate lots of state changes at a single stream ID.
        #
        # We do this by having two branches in the DAG. On one, we have a moderator
        # which generates lots of state; on the other, we de-op the moderator,
        # thus invalidating all the state.

        OTHER_USER = "@other_user:localhost"

        # have the user join
        inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)

        # Update existing power levels with mod at PL50
        pls = self.helper.get_state(
            self.room_id, EventTypes.PowerLevels, tok=self.user_tok
        )
        pls["users"][OTHER_USER] = 50
        self.helper.send_state(
            self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
        )

        # this is the point in the DAG where we make a fork
        fork_point = self.get_success(
            self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
        )  # type: List[str]

        events = [
            self._inject_state_event(sender=OTHER_USER)
            for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT)
        ]

        self.replicate()
        # all those events and state changes should have landed
        self.assertGreaterEqual(
            len(self.test_handler.received_rdata_rows), 2 * len(events)
        )

        # disconnect, so that we can stack up the changes
        self.disconnect()
        self.test_handler.received_rdata_rows.clear()

        # a state event which doesn't get rolled back, to check that the state
        # before the huge update comes through ok
        state1 = self._inject_state_event()

        # roll back all the state by de-modding the user
        prev_events = fork_point
        pls["users"][OTHER_USER] = 0
        pl_event = inject_event(
            self.hs,
            prev_event_ids=prev_events,
            type=EventTypes.PowerLevels,
            state_key="",
            sender=self.user_id,
            room_id=self.room_id,
            content=pls,
        )

        # one more bit of state that doesn't get rolled back
        state2 = self._inject_state_event()

        # tell the notifier to catch up to avoid duplicate rows.
        # workaround for https://github.com/matrix-org/synapse/issues/7360
        # FIXME remove this when the above is fixed
        self.replicate()

        # check we're testing what we think we are: no rows should yet have been
        # received
        self.assertEqual([], self.test_handler.received_rdata_rows)

        # now reconnect to pull the updates
        self.reconnect()
        self.replicate()

        # now we should have received all the expected rows in the right order.
        #
        # we expect:
        #
        # - two rows for state1
        # - the PL event row, plus state rows for the PL event and each
        #   of the states that got reverted.
        # - two rows for state2

        received_rows = self.test_handler.received_rdata_rows

        # first check the first two rows, which should be state1
        stream_name, token, row = received_rows.pop(0)
        self.assertEqual("events", stream_name)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertEqual(row.type, "ev")
        self.assertIsInstance(row.data, EventsStreamEventRow)
        self.assertEqual(row.data.event_id, state1.event_id)

        stream_name, token, row = received_rows.pop(0)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertEqual(row.type, "state")
        self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
        self.assertEqual(row.data.event_id, state1.event_id)

        # now the last two rows, which should be state2
        stream_name, token, row = received_rows.pop(-2)
        self.assertEqual("events", stream_name)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertEqual(row.type, "ev")
        self.assertIsInstance(row.data, EventsStreamEventRow)
        self.assertEqual(row.data.event_id, state2.event_id)

        stream_name, token, row = received_rows.pop(-1)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertEqual(row.type, "state")
        self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
        self.assertEqual(row.data.event_id, state2.event_id)

        # that should leave us with the rows for the PL event
        self.assertEqual(len(received_rows), len(events) + 2)

        stream_name, token, row = received_rows.pop(0)
        self.assertEqual("events", stream_name)
        self.assertIsInstance(row, EventsStreamRow)
        self.assertEqual(row.type, "ev")
        self.assertIsInstance(row.data, EventsStreamEventRow)
        self.assertEqual(row.data.event_id, pl_event.event_id)

        # the state rows are unsorted
        state_rows = []  # type: List[EventsStreamCurrentStateRow]
        for stream_name, token, row in received_rows:
            self.assertEqual("events", stream_name)
            self.assertIsInstance(row, EventsStreamRow)
            self.assertEqual(row.type, "state")
            self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
            state_rows.append(row.data)

        state_rows.sort(key=lambda r: r.state_key)

        sr = state_rows.pop(0)
        self.assertEqual(sr.type, EventTypes.PowerLevels)
        self.assertEqual(sr.event_id, pl_event.event_id)
        for sr in state_rows:
            self.assertEqual(sr.type, "test_state_event")
            # "None" indicates the state has been deleted
            self.assertIsNone(sr.event_id)

    def test_update_function_state_row_limit(self):
        """Test replication with many state events over several stream ids.
        """
        # we want to generate lots of state changes, but for this test, we want to
        # spread out the state changes over a few stream IDs.
        #
        # We do this by having two branches in the DAG. On one, we have four
        # moderators, each of which generates lots of state; on the other, we
        # de-op the users, thus invalidating all the state.

        NUM_USERS = 4
        STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // 4 + 1

        user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)]

        # have the users join
        for u in user_ids:
            inject_member_event(self.hs, self.room_id, u, Membership.JOIN)

        # Update existing power levels with mod at PL50
        pls = self.helper.get_state(
            self.room_id, EventTypes.PowerLevels, tok=self.user_tok
        )
        pls["users"].update({u: 50 for u in user_ids})
        self.helper.send_state(
            self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
        )

        # this is the point in the DAG where we make a fork
        fork_point = self.get_success(
            self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
        )  # type: List[str]

        events = []  # type: List[EventBase]
        for user in user_ids:
            events.extend(
                self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
            )

        self.replicate()

        # all those events and state changes should have landed
        self.assertGreaterEqual(
            len(self.test_handler.received_rdata_rows), 2 * len(events)
        )

        # disconnect, so that we can stack up the changes
        self.disconnect()
        self.test_handler.received_rdata_rows.clear()

        # now roll back all that state by de-modding the users
        prev_events = fork_point
        pl_events = []
        for u in user_ids:
            pls["users"][u] = 0
            e = inject_event(
                self.hs,
                prev_event_ids=prev_events,
                type=EventTypes.PowerLevels,
                state_key="",
                sender=self.user_id,
                room_id=self.room_id,
                content=pls,
            )
            prev_events = [e.event_id]
            pl_events.append(e)

        # tell the notifier to catch up to avoid duplicate rows.
        # workaround for https://github.com/matrix-org/synapse/issues/7360
        # FIXME remove this when the above is fixed
        self.replicate()

        # check we're testing what we think we are: no rows should yet have been
        # received
        self.assertEqual([], self.test_handler.received_rdata_rows)

        # now reconnect to pull the updates
        self.reconnect()
        self.replicate()

        # we should have received all the expected rows in the right order
        received_rows = self.test_handler.received_rdata_rows
        self.assertGreaterEqual(len(received_rows), len(events))

        for i in range(NUM_USERS):
            # for each user, we expect the PL event row, followed by state rows for
            # the PL event and each of the states that got reverted.
            stream_name, token, row = received_rows.pop(0)
            self.assertEqual("events", stream_name)
            self.assertIsInstance(row, EventsStreamRow)
            self.assertEqual(row.type, "ev")
            self.assertIsInstance(row.data, EventsStreamEventRow)
            self.assertEqual(row.data.event_id, pl_events[i].event_id)

            # the state rows are unsorted
            state_rows = []  # type: List[EventsStreamCurrentStateRow]
            for j in range(STATES_PER_USER + 1):
                stream_name, token, row = received_rows.pop(0)
                self.assertEqual("events", stream_name)
                self.assertIsInstance(row, EventsStreamRow)
                self.assertEqual(row.type, "state")
                self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
                state_rows.append(row.data)

            state_rows.sort(key=lambda r: r.state_key)

            sr = state_rows.pop(0)
            self.assertEqual(sr.type, EventTypes.PowerLevels)
            self.assertEqual(sr.event_id, pl_events[i].event_id)
            for sr in state_rows:
                self.assertEqual(sr.type, "test_state_event")
                # "None" indicates the state has been deleted
                self.assertIsNone(sr.event_id)

        self.assertEqual([], received_rows)

    event_count = 0

    def _inject_test_event(
        self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
    ) -> EventBase:
        if sender is None:
            sender = self.user_id

        if body is None:
            body = "event %i" % (self.event_count,)
            self.event_count += 1

        return inject_event(
            self.hs,
            room_id=self.room_id,
            sender=sender,
            type="test_event",
            content={"body": body},
            **kwargs
        )

    def _inject_state_event(
        self,
        body: Optional[str] = None,
        state_key: Optional[str] = None,
        sender: Optional[str] = None,
    ) -> EventBase:
        if sender is None:
            sender = self.user_id

        if state_key is None:
            state_key = "state_%i" % (self.event_count,)
            self.event_count += 1

        if body is None:
            body = "state event %s" % (state_key,)

        return inject_event(
            self.hs,
            room_id=self.room_id,
            sender=sender,
            type="test_state_event",
            state_key=state_key,
            content={"body": body},
        )

tests/replication/tcp/streams/test_receipts.py

@@ -12,6 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# type: ignore
+
+from mock import Mock
+
 from synapse.replication.tcp.streams._base import ReceiptsStream

 from tests.replication.tcp.streams._base import BaseStreamTestCase
@@ -20,11 +25,14 @@ USER_ID = "@feeling:blue"

 class ReceiptsStreamTestCase(BaseStreamTestCase):
+    def _build_replication_data_handler(self):
+        return Mock(wraps=super()._build_replication_data_handler())
+
     def test_receipt(self):
         self.reconnect()

         # make the client subscribe to the receipts stream
-        self.test_handler.streams.add("receipts")
+        self.test_handler.stream_positions.update({"receipts": 0})

         # tell the master to send a new receipt
         self.get_success(

tests/replication/tcp/streams/test_typing.py

@@ -12,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from mock import Mock
+
 from synapse.handlers.typing import RoomMember
 from synapse.replication.http import streams
 from synapse.replication.tcp.streams import TypingStream
@@ -26,6 +28,9 @@ class TypingStreamTestCase(BaseStreamTestCase):
         streams.register_servlets,
     ]

+    def _build_replication_data_handler(self):
+        return Mock(wraps=super()._build_replication_data_handler())
+
     def test_typing(self):
         typing = self.hs.get_typing_handler()
@@ -33,8 +38,8 @@
         self.reconnect()

-        # make the client subscribe to the receipts stream
-        self.test_handler.streams.add("typing")
+        # make the client subscribe to the typing stream
+        self.test_handler.stream_positions.update({"typing": 0})

         typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
@@ -75,6 +80,6 @@
         stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
-        row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
+        row = rdata_rows[0]
         self.assertEqual(room_id, row.room_id)
         self.assertEqual([], row.user_ids)

tests/rest/client/v1/utils.py

@@ -39,7 +39,7 @@ class RestHelper(object):
     resource = attr.ib()
     auth_user_id = attr.ib()

-    def create_room_as(self, room_creator, is_public=True, tok=None):
+    def create_room_as(self, room_creator=None, is_public=True, tok=None):
         temp_id = self.auth_user_id
         self.auth_user_id = room_creator
         path = "/_matrix/client/r0/createRoom"

tests/test_utils/__init__.py

@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2019 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,3 +17,22 @@
 """
 Utilities for running the unit tests
 """
+from typing import Awaitable, TypeVar
+
+TV = TypeVar("TV")
+
+
+def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
+    """Get the result from an Awaitable which should have completed
+
+    Asserts that the given awaitable has a result ready, and returns its value
+    """
+    i = awaitable.__await__()
+    try:
+        next(i)
+    except StopIteration as e:
+        # awaitable returned a result
+        return e.value
+
+    # if next didn't raise, the awaitable hasn't completed.
+    raise Exception("awaitable has not yet completed")
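
Since these tests drive a synchronous fake reactor, there is never a loop to await on; the helper instead steps the awaitable once and expects it to complete immediately. A quick demonstration of the contract with toy awaitables (not Synapse code):

    async def ready() -> int:
        return 42  # completes without ever suspending

    assert get_awaitable_result(ready()) == 42


    class Pending:
        """An awaitable that suspends once before returning a result."""

        def __await__(self):
            yield  # still waiting: there is no result to collect yet
            return 42


    threw = False
    try:
        get_awaitable_result(Pending())
    except Exception:
        threw = True
    assert threw  # the awaitable had not yet completed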

tests/test_utils/event_injection.py (new file)

@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import synapse.server
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.types import Collection

from tests.test_utils import get_awaitable_result


"""
Utility functions for poking events into the storage of the server under test.
"""


def inject_member_event(
    hs: synapse.server.HomeServer,
    room_id: str,
    sender: str,
    membership: str,
    target: Optional[str] = None,
    extra_content: Optional[dict] = None,
    **kwargs
) -> EventBase:
    """Inject a membership event into a room."""
    if target is None:
        target = sender

    content = {"membership": membership}
    if extra_content:
        content.update(extra_content)

    return inject_event(
        hs,
        room_id=room_id,
        type=EventTypes.Member,
        sender=sender,
        state_key=target,
        content=content,
        **kwargs
    )


def inject_event(
    hs: synapse.server.HomeServer,
    room_version: Optional[str] = None,
    prev_event_ids: Optional[Collection[str]] = None,
    **kwargs
) -> EventBase:
    """Inject a generic event into a room

    Args:
        hs: the homeserver under test
        room_version: the version of the room we're inserting into.
            if not specified, will be looked up
        prev_event_ids: prev_events for the event. If not specified, will be looked up
        kwargs: fields for the event to be created
    """
    test_reactor = hs.get_reactor()

    if room_version is None:
        d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
        test_reactor.advance(0)
        room_version = get_awaitable_result(d)

    builder = hs.get_event_builder_factory().for_room_version(
        KNOWN_ROOM_VERSIONS[room_version], kwargs
    )
    d = hs.get_event_creation_handler().create_new_client_event(
        builder, prev_event_ids=prev_event_ids
    )
    test_reactor.advance(0)
    event, context = get_awaitable_result(d)

    d = hs.get_storage().persistence.persist_event(event, context)
    test_reactor.advance(0)
    get_awaitable_result(d)

    return event

tests/unittest.py

@@ -32,7 +32,6 @@ from twisted.python.threadpool import ThreadPool
 from twisted.trial import unittest

 from synapse.api.constants import EventTypes, Membership
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.ratelimiting import FederationRateLimitConfig
 from synapse.federation.transport import server as federation_server
@@ -55,6 +54,7 @@
     render,
     setup_test_homeserver,
 )
+from tests.test_utils import event_injection
 from tests.test_utils.logging_setup import setup_logging
 from tests.utils import default_config, setupdb
@@ -596,36 +596,14 @@ class HomeserverTestCase(TestCase):
         """
         Inject a membership event into a room.

+        Deprecated: use event_injection.inject_member_event directly
+
         Args:
             room: Room ID to inject the event into.
             user: MXID of the user to inject the membership for.
             membership: The membership type.
         """
-        event_builder_factory = self.hs.get_event_builder_factory()
-        event_creation_handler = self.hs.get_event_creation_handler()
-
-        room_version = self.get_success(
-            self.hs.get_datastore().get_room_version_id(room)
-        )
-
-        builder = event_builder_factory.for_room_version(
-            KNOWN_ROOM_VERSIONS[room_version],
-            {
-                "type": EventTypes.Member,
-                "sender": user,
-                "state_key": user,
-                "room_id": room,
-                "content": {"membership": membership},
-            },
-        )
-
-        event, context = self.get_success(
-            event_creation_handler.create_new_client_event(builder)
-        )
-
-        self.get_success(
-            self.hs.get_storage().persistence.persist_event(event, context)
-        )
+        event_injection.inject_member_event(self.hs, room, user, membership)


 class FederatingHomeserverTestCase(HomeserverTestCase):

tox.ini

@@ -204,6 +204,8 @@ commands = mypy \
             synapse/storage/database.py \
             synapse/streams \
             synapse/util/caches/stream_change_cache.py \
+            tests/replication/tcp/streams \
+            tests/test_utils \
             tests/util/test_stream_change_cache.py

 # To find all folders that pass mypy you run: