mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-01-13 09:29:24 -05:00
Add missing type hints to synapse.replication. (#11938)
This commit is contained in:
parent
8c94b3abe9
commit
d0e78af35e
1
changelog.d/11938.misc
Normal file
1
changelog.d/11938.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add missing type hints to replication code.
|
3
mypy.ini
3
mypy.ini
@ -169,6 +169,9 @@ disallow_untyped_defs = True
|
|||||||
[mypy-synapse.push.*]
|
[mypy-synapse.push.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.replication.*]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.rest.*]
|
[mypy-synapse.rest.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ class SlavedIdTracker(AbstractStreamIdTracker):
|
|||||||
for table, column in extra_tables:
|
for table, column in extra_tables:
|
||||||
self.advance(None, _load_current_id(db_conn, table, column))
|
self.advance(None, _load_current_id(db_conn, table, column))
|
||||||
|
|
||||||
def advance(self, instance_name: Optional[str], new_id: int):
|
def advance(self, instance_name: Optional[str], new_id: int) -> None:
|
||||||
self._current = (max if self.step > 0 else min)(self._current, new_id)
|
self._current = (max if self.step > 0 else min)(self._current, new_id)
|
||||||
|
|
||||||
def get_current_token(self) -> int:
|
def get_current_token(self) -> int:
|
||||||
|
@ -37,7 +37,9 @@ class SlavedClientIpStore(BaseSlavedStore):
|
|||||||
cache_name="client_ip_last_seen", max_size=50000
|
cache_name="client_ip_last_seen", max_size=50000
|
||||||
)
|
)
|
||||||
|
|
||||||
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
|
async def insert_client_ip(
|
||||||
|
self, user_id: str, access_token: str, ip: str, user_agent: str, device_id: str
|
||||||
|
) -> None:
|
||||||
now = int(self._clock.time_msec())
|
now = int(self._clock.time_msec())
|
||||||
key = (user_id, access_token, ip)
|
key = (user_id, access_token, ip)
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any, Iterable
|
||||||
|
|
||||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||||
@ -60,7 +60,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
|||||||
def get_device_stream_token(self) -> int:
|
def get_device_stream_token(self) -> int:
|
||||||
return self._device_list_id_gen.get_current_token()
|
return self._device_list_id_gen.get_current_token()
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
def process_replication_rows(
|
||||||
|
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
|
||||||
|
) -> None:
|
||||||
if stream_name == DeviceListsStream.NAME:
|
if stream_name == DeviceListsStream.NAME:
|
||||||
self._device_list_id_gen.advance(instance_name, token)
|
self._device_list_id_gen.advance(instance_name, token)
|
||||||
self._invalidate_caches_for_devices(token, rows)
|
self._invalidate_caches_for_devices(token, rows)
|
||||||
@ -70,7 +72,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
|||||||
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
|
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
|
||||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
|
|
||||||
def _invalidate_caches_for_devices(self, token, rows):
|
def _invalidate_caches_for_devices(
|
||||||
|
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
|
||||||
|
) -> None:
|
||||||
for row in rows:
|
for row in rows:
|
||||||
# The entities are either user IDs (starting with '@') whose devices
|
# The entities are either user IDs (starting with '@') whose devices
|
||||||
# have changed, or remote servers that we need to tell about
|
# have changed, or remote servers that we need to tell about
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any, Iterable
|
||||||
|
|
||||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||||
@ -44,10 +44,12 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
|
|||||||
self._group_updates_id_gen.get_current_token(),
|
self._group_updates_id_gen.get_current_token(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_group_stream_token(self):
|
def get_group_stream_token(self) -> int:
|
||||||
return self._group_updates_id_gen.get_current_token()
|
return self._group_updates_id_gen.get_current_token()
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
def process_replication_rows(
|
||||||
|
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
|
||||||
|
) -> None:
|
||||||
if stream_name == GroupServerStream.NAME:
|
if stream_name == GroupServerStream.NAME:
|
||||||
self._group_updates_id_gen.advance(instance_name, token)
|
self._group_updates_id_gen.advance(instance_name, token)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Any, Iterable
|
||||||
|
|
||||||
from synapse.replication.tcp.streams import PushRulesStream
|
from synapse.replication.tcp.streams import PushRulesStream
|
||||||
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
|
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
|
||||||
@ -20,10 +21,12 @@ from .events import SlavedEventStore
|
|||||||
|
|
||||||
|
|
||||||
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
|
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
|
||||||
def get_max_push_rules_stream_id(self):
|
def get_max_push_rules_stream_id(self) -> int:
|
||||||
return self._push_rules_stream_id_gen.get_current_token()
|
return self._push_rules_stream_id_gen.get_current_token()
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
def process_replication_rows(
|
||||||
|
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
|
||||||
|
) -> None:
|
||||||
if stream_name == PushRulesStream.NAME:
|
if stream_name == PushRulesStream.NAME:
|
||||||
self._push_rules_stream_id_gen.advance(instance_name, token)
|
self._push_rules_stream_id_gen.advance(instance_name, token)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any, Iterable
|
||||||
|
|
||||||
from synapse.replication.tcp.streams import PushersStream
|
from synapse.replication.tcp.streams import PushersStream
|
||||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||||
@ -41,8 +41,8 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
|||||||
return self._pushers_id_gen.get_current_token()
|
return self._pushers_id_gen.get_current_token()
|
||||||
|
|
||||||
def process_replication_rows(
|
def process_replication_rows(
|
||||||
self, stream_name: str, instance_name: str, token, rows
|
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
if stream_name == PushersStream.NAME:
|
if stream_name == PushersStream.NAME:
|
||||||
self._pushers_id_gen.advance(instance_name, token) # type: ignore
|
self._pushers_id_gen.advance(instance_name, token)
|
||||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
|
@ -14,10 +14,12 @@
|
|||||||
"""A replication client for use by synapse workers.
|
"""A replication client for use by synapse workers.
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
|
from twisted.internet.interfaces import IAddress, IConnector
|
||||||
from twisted.internet.protocol import ReconnectingClientFactory
|
from twisted.internet.protocol import ReconnectingClientFactory
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.federation import send_queue
|
from synapse.federation import send_queue
|
||||||
@ -79,10 +81,10 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
|
|||||||
|
|
||||||
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
|
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
|
||||||
|
|
||||||
def startedConnecting(self, connector):
|
def startedConnecting(self, connector: IConnector) -> None:
|
||||||
logger.info("Connecting to replication: %r", connector.getDestination())
|
logger.info("Connecting to replication: %r", connector.getDestination())
|
||||||
|
|
||||||
def buildProtocol(self, addr):
|
def buildProtocol(self, addr: IAddress) -> ClientReplicationStreamProtocol:
|
||||||
logger.info("Connected to replication: %r", addr)
|
logger.info("Connected to replication: %r", addr)
|
||||||
return ClientReplicationStreamProtocol(
|
return ClientReplicationStreamProtocol(
|
||||||
self.hs,
|
self.hs,
|
||||||
@ -92,11 +94,11 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
|
|||||||
self.command_handler,
|
self.command_handler,
|
||||||
)
|
)
|
||||||
|
|
||||||
def clientConnectionLost(self, connector, reason):
|
def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
|
||||||
logger.error("Lost replication conn: %r", reason)
|
logger.error("Lost replication conn: %r", reason)
|
||||||
ReconnectingClientFactory.clientConnectionLost(self, connector, reason)
|
ReconnectingClientFactory.clientConnectionLost(self, connector, reason)
|
||||||
|
|
||||||
def clientConnectionFailed(self, connector, reason):
|
def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
|
||||||
logger.error("Failed to connect to replication: %r", reason)
|
logger.error("Failed to connect to replication: %r", reason)
|
||||||
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
|
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
|
||||||
|
|
||||||
@ -131,7 +133,7 @@ class ReplicationDataHandler:
|
|||||||
|
|
||||||
async def on_rdata(
|
async def on_rdata(
|
||||||
self, stream_name: str, instance_name: str, token: int, rows: list
|
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||||
):
|
) -> None:
|
||||||
"""Called to handle a batch of replication data with a given stream token.
|
"""Called to handle a batch of replication data with a given stream token.
|
||||||
|
|
||||||
By default this just pokes the slave store. Can be overridden in subclasses to
|
By default this just pokes the slave store. Can be overridden in subclasses to
|
||||||
@ -252,14 +254,16 @@ class ReplicationDataHandler:
|
|||||||
# loop. (This maintains the order so no need to resort)
|
# loop. (This maintains the order so no need to resort)
|
||||||
waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
|
waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
|
||||||
|
|
||||||
async def on_position(self, stream_name: str, instance_name: str, token: int):
|
async def on_position(
|
||||||
|
self, stream_name: str, instance_name: str, token: int
|
||||||
|
) -> None:
|
||||||
await self.on_rdata(stream_name, instance_name, token, [])
|
await self.on_rdata(stream_name, instance_name, token, [])
|
||||||
|
|
||||||
# We poke the generic "replication" notifier to wake anything up that
|
# We poke the generic "replication" notifier to wake anything up that
|
||||||
# may be streaming.
|
# may be streaming.
|
||||||
self.notifier.notify_replication()
|
self.notifier.notify_replication()
|
||||||
|
|
||||||
def on_remote_server_up(self, server: str):
|
def on_remote_server_up(self, server: str) -> None:
|
||||||
"""Called when get a new REMOTE_SERVER_UP command."""
|
"""Called when get a new REMOTE_SERVER_UP command."""
|
||||||
|
|
||||||
# Let's wake up the transaction queue for the server in case we have
|
# Let's wake up the transaction queue for the server in case we have
|
||||||
@ -269,7 +273,7 @@ class ReplicationDataHandler:
|
|||||||
|
|
||||||
async def wait_for_stream_position(
|
async def wait_for_stream_position(
|
||||||
self, instance_name: str, stream_name: str, position: int
|
self, instance_name: str, stream_name: str, position: int
|
||||||
):
|
) -> None:
|
||||||
"""Wait until this instance has received updates up to and including
|
"""Wait until this instance has received updates up to and including
|
||||||
the given stream position.
|
the given stream position.
|
||||||
"""
|
"""
|
||||||
@ -304,7 +308,7 @@ class ReplicationDataHandler:
|
|||||||
"Finished waiting for repl stream %r to reach %s", stream_name, position
|
"Finished waiting for repl stream %r to reach %s", stream_name, position
|
||||||
)
|
)
|
||||||
|
|
||||||
def stop_pusher(self, user_id, app_id, pushkey):
|
def stop_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
|
||||||
if not self._notify_pushers:
|
if not self._notify_pushers:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -316,13 +320,13 @@ class ReplicationDataHandler:
|
|||||||
logger.info("Stopping pusher %r / %r", user_id, key)
|
logger.info("Stopping pusher %r / %r", user_id, key)
|
||||||
pusher.on_stop()
|
pusher.on_stop()
|
||||||
|
|
||||||
async def start_pusher(self, user_id, app_id, pushkey):
|
async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
|
||||||
if not self._notify_pushers:
|
if not self._notify_pushers:
|
||||||
return
|
return
|
||||||
|
|
||||||
key = "%s:%s" % (app_id, pushkey)
|
key = "%s:%s" % (app_id, pushkey)
|
||||||
logger.info("Starting pusher %r / %r", user_id, key)
|
logger.info("Starting pusher %r / %r", user_id, key)
|
||||||
return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
|
await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
|
||||||
|
|
||||||
|
|
||||||
class FederationSenderHandler:
|
class FederationSenderHandler:
|
||||||
@ -353,10 +357,12 @@ class FederationSenderHandler:
|
|||||||
|
|
||||||
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
|
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
|
||||||
|
|
||||||
def wake_destination(self, server: str):
|
def wake_destination(self, server: str) -> None:
|
||||||
self.federation_sender.wake_destination(server)
|
self.federation_sender.wake_destination(server)
|
||||||
|
|
||||||
async def process_replication_rows(self, stream_name, token, rows):
|
async def process_replication_rows(
|
||||||
|
self, stream_name: str, token: int, rows: list
|
||||||
|
) -> None:
|
||||||
# The federation stream contains things that we want to send out, e.g.
|
# The federation stream contains things that we want to send out, e.g.
|
||||||
# presence, typing, etc.
|
# presence, typing, etc.
|
||||||
if stream_name == "federation":
|
if stream_name == "federation":
|
||||||
@ -384,11 +390,12 @@ class FederationSenderHandler:
|
|||||||
for host in hosts:
|
for host in hosts:
|
||||||
self.federation_sender.send_device_messages(host)
|
self.federation_sender.send_device_messages(host)
|
||||||
|
|
||||||
async def _on_new_receipts(self, rows):
|
async def _on_new_receipts(
|
||||||
|
self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow]
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
|
rows: new receipts to be processed
|
||||||
new receipts to be processed
|
|
||||||
"""
|
"""
|
||||||
for receipt in rows:
|
for receipt in rows:
|
||||||
# we only want to send on receipts for our own users
|
# we only want to send on receipts for our own users
|
||||||
@ -408,7 +415,7 @@ class FederationSenderHandler:
|
|||||||
)
|
)
|
||||||
await self.federation_sender.send_read_receipt(receipt_info)
|
await self.federation_sender.send_read_receipt(receipt_info)
|
||||||
|
|
||||||
async def update_token(self, token):
|
async def update_token(self, token: int) -> None:
|
||||||
"""Update the record of where we have processed to in the federation stream.
|
"""Update the record of where we have processed to in the federation stream.
|
||||||
|
|
||||||
Called after we have processed a an update received over replication. Sends
|
Called after we have processed a an update received over replication. Sends
|
||||||
@ -428,7 +435,7 @@ class FederationSenderHandler:
|
|||||||
|
|
||||||
run_as_background_process("_save_and_send_ack", self._save_and_send_ack)
|
run_as_background_process("_save_and_send_ack", self._save_and_send_ack)
|
||||||
|
|
||||||
async def _save_and_send_ack(self):
|
async def _save_and_send_ack(self) -> None:
|
||||||
"""Save the current federation position in the database and send an ACK
|
"""Save the current federation position in the database and send an ACK
|
||||||
to master with where we're up to.
|
to master with where we're up to.
|
||||||
"""
|
"""
|
||||||
|
@ -18,12 +18,15 @@ allowed to be sent by which side.
|
|||||||
"""
|
"""
|
||||||
import abc
|
import abc
|
||||||
import logging
|
import logging
|
||||||
from typing import Tuple, Type
|
from typing import Optional, Tuple, Type, TypeVar
|
||||||
|
|
||||||
|
from synapse.replication.tcp.streams._base import StreamRow
|
||||||
from synapse.util import json_decoder, json_encoder
|
from synapse.util import json_decoder, json_encoder
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="Command")
|
||||||
|
|
||||||
|
|
||||||
class Command(metaclass=abc.ABCMeta):
|
class Command(metaclass=abc.ABCMeta):
|
||||||
"""The base command class.
|
"""The base command class.
|
||||||
@ -38,7 +41,7 @@ class Command(metaclass=abc.ABCMeta):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def from_line(cls, line):
|
def from_line(cls: Type[T], line: str) -> T:
|
||||||
"""Deserialises a line from the wire into this command. `line` does not
|
"""Deserialises a line from the wire into this command. `line` does not
|
||||||
include the command.
|
include the command.
|
||||||
"""
|
"""
|
||||||
@ -49,21 +52,24 @@ class Command(metaclass=abc.ABCMeta):
|
|||||||
prefix.
|
prefix.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_logcontext_id(self):
|
def get_logcontext_id(self) -> str:
|
||||||
"""Get a suitable string for the logcontext when processing this command"""
|
"""Get a suitable string for the logcontext when processing this command"""
|
||||||
|
|
||||||
# by default, we just use the command name.
|
# by default, we just use the command name.
|
||||||
return self.NAME
|
return self.NAME
|
||||||
|
|
||||||
|
|
||||||
|
SC = TypeVar("SC", bound="_SimpleCommand")
|
||||||
|
|
||||||
|
|
||||||
class _SimpleCommand(Command):
|
class _SimpleCommand(Command):
|
||||||
"""An implementation of Command whose argument is just a 'data' string."""
|
"""An implementation of Command whose argument is just a 'data' string."""
|
||||||
|
|
||||||
def __init__(self, data):
|
def __init__(self, data: str):
|
||||||
self.data = data
|
self.data = data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_line(cls, line):
|
def from_line(cls: Type[SC], line: str) -> SC:
|
||||||
return cls(line)
|
return cls(line)
|
||||||
|
|
||||||
def to_line(self) -> str:
|
def to_line(self) -> str:
|
||||||
@ -109,14 +115,16 @@ class RdataCommand(Command):
|
|||||||
|
|
||||||
NAME = "RDATA"
|
NAME = "RDATA"
|
||||||
|
|
||||||
def __init__(self, stream_name, instance_name, token, row):
|
def __init__(
|
||||||
|
self, stream_name: str, instance_name: str, token: Optional[int], row: StreamRow
|
||||||
|
):
|
||||||
self.stream_name = stream_name
|
self.stream_name = stream_name
|
||||||
self.instance_name = instance_name
|
self.instance_name = instance_name
|
||||||
self.token = token
|
self.token = token
|
||||||
self.row = row
|
self.row = row
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_line(cls, line):
|
def from_line(cls: Type["RdataCommand"], line: str) -> "RdataCommand":
|
||||||
stream_name, instance_name, token, row_json = line.split(" ", 3)
|
stream_name, instance_name, token, row_json = line.split(" ", 3)
|
||||||
return cls(
|
return cls(
|
||||||
stream_name,
|
stream_name,
|
||||||
@ -125,7 +133,7 @@ class RdataCommand(Command):
|
|||||||
json_decoder.decode(row_json),
|
json_decoder.decode(row_json),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_line(self):
|
def to_line(self) -> str:
|
||||||
return " ".join(
|
return " ".join(
|
||||||
(
|
(
|
||||||
self.stream_name,
|
self.stream_name,
|
||||||
@ -135,7 +143,7 @@ class RdataCommand(Command):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_logcontext_id(self):
|
def get_logcontext_id(self) -> str:
|
||||||
return "RDATA-" + self.stream_name
|
return "RDATA-" + self.stream_name
|
||||||
|
|
||||||
|
|
||||||
@ -164,18 +172,20 @@ class PositionCommand(Command):
|
|||||||
|
|
||||||
NAME = "POSITION"
|
NAME = "POSITION"
|
||||||
|
|
||||||
def __init__(self, stream_name, instance_name, prev_token, new_token):
|
def __init__(
|
||||||
|
self, stream_name: str, instance_name: str, prev_token: int, new_token: int
|
||||||
|
):
|
||||||
self.stream_name = stream_name
|
self.stream_name = stream_name
|
||||||
self.instance_name = instance_name
|
self.instance_name = instance_name
|
||||||
self.prev_token = prev_token
|
self.prev_token = prev_token
|
||||||
self.new_token = new_token
|
self.new_token = new_token
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_line(cls, line):
|
def from_line(cls: Type["PositionCommand"], line: str) -> "PositionCommand":
|
||||||
stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
|
stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
|
||||||
return cls(stream_name, instance_name, int(prev_token), int(new_token))
|
return cls(stream_name, instance_name, int(prev_token), int(new_token))
|
||||||
|
|
||||||
def to_line(self):
|
def to_line(self) -> str:
|
||||||
return " ".join(
|
return " ".join(
|
||||||
(
|
(
|
||||||
self.stream_name,
|
self.stream_name,
|
||||||
@ -218,14 +228,14 @@ class ReplicateCommand(Command):
|
|||||||
|
|
||||||
NAME = "REPLICATE"
|
NAME = "REPLICATE"
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_line(cls, line):
|
def from_line(cls: Type[T], line: str) -> T:
|
||||||
return cls()
|
return cls()
|
||||||
|
|
||||||
def to_line(self):
|
def to_line(self) -> str:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
@ -247,14 +257,16 @@ class UserSyncCommand(Command):
|
|||||||
|
|
||||||
NAME = "USER_SYNC"
|
NAME = "USER_SYNC"
|
||||||
|
|
||||||
def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
|
def __init__(
|
||||||
|
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
|
||||||
|
):
|
||||||
self.instance_id = instance_id
|
self.instance_id = instance_id
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
self.is_syncing = is_syncing
|
self.is_syncing = is_syncing
|
||||||
self.last_sync_ms = last_sync_ms
|
self.last_sync_ms = last_sync_ms
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_line(cls, line):
|
def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand":
|
||||||
instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
|
instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
|
||||||
|
|
||||||
if state not in ("start", "end"):
|
if state not in ("start", "end"):
|
||||||
@ -262,7 +274,7 @@ class UserSyncCommand(Command):
|
|||||||
|
|
||||||
return cls(instance_id, user_id, state == "start", int(last_sync_ms))
|
return cls(instance_id, user_id, state == "start", int(last_sync_ms))
|
||||||
|
|
||||||
def to_line(self):
|
def to_line(self) -> str:
|
||||||
return " ".join(
|
return " ".join(
|
||||||
(
|
(
|
||||||
self.instance_id,
|
self.instance_id,
|
||||||
@ -286,14 +298,16 @@ class ClearUserSyncsCommand(Command):
|
|||||||
|
|
||||||
NAME = "CLEAR_USER_SYNC"
|
NAME = "CLEAR_USER_SYNC"
|
||||||
|
|
||||||
def __init__(self, instance_id):
|
def __init__(self, instance_id: str):
|
||||||
self.instance_id = instance_id
|
self.instance_id = instance_id
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_line(cls, line):
|
def from_line(
|
||||||
|
cls: Type["ClearUserSyncsCommand"], line: str
|
||||||
|
) -> "ClearUserSyncsCommand":
|
||||||
return cls(line)
|
return cls(line)
|
||||||
|
|
||||||
def to_line(self):
|
def to_line(self) -> str:
|
||||||
return self.instance_id
|
return self.instance_id
|
||||||
|
|
||||||
|
|
||||||
@ -316,7 +330,9 @@ class FederationAckCommand(Command):
|
|||||||
self.token = token
|
self.token = token
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_line(cls, line: str) -> "FederationAckCommand":
|
def from_line(
|
||||||
|
cls: Type["FederationAckCommand"], line: str
|
||||||
|
) -> "FederationAckCommand":
|
||||||
instance_name, token = line.split(" ")
|
instance_name, token = line.split(" ")
|
||||||
return cls(instance_name, int(token))
|
return cls(instance_name, int(token))
|
||||||
|
|
||||||
@ -334,7 +350,15 @@ class UserIpCommand(Command):
|
|||||||
|
|
||||||
NAME = "USER_IP"
|
NAME = "USER_IP"
|
||||||
|
|
||||||
def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
access_token: str,
|
||||||
|
ip: str,
|
||||||
|
user_agent: str,
|
||||||
|
device_id: str,
|
||||||
|
last_seen: int,
|
||||||
|
):
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
self.access_token = access_token
|
self.access_token = access_token
|
||||||
self.ip = ip
|
self.ip = ip
|
||||||
@ -343,14 +367,14 @@ class UserIpCommand(Command):
|
|||||||
self.last_seen = last_seen
|
self.last_seen = last_seen
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_line(cls, line):
|
def from_line(cls: Type["UserIpCommand"], line: str) -> "UserIpCommand":
|
||||||
user_id, jsn = line.split(" ", 1)
|
user_id, jsn = line.split(" ", 1)
|
||||||
|
|
||||||
access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
|
access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
|
||||||
|
|
||||||
return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
|
return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
|
||||||
|
|
||||||
def to_line(self):
|
def to_line(self) -> str:
|
||||||
return (
|
return (
|
||||||
self.user_id
|
self.user_id
|
||||||
+ " "
|
+ " "
|
||||||
|
@ -261,7 +261,7 @@ class ReplicationCommandHandler:
|
|||||||
"process-replication-data", self._unsafe_process_queue, stream_name
|
"process-replication-data", self._unsafe_process_queue, stream_name
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _unsafe_process_queue(self, stream_name: str):
|
async def _unsafe_process_queue(self, stream_name: str) -> None:
|
||||||
"""Processes the command queue for the given stream, until it is empty
|
"""Processes the command queue for the given stream, until it is empty
|
||||||
|
|
||||||
Does not check if there is already a thread processing the queue, hence "unsafe"
|
Does not check if there is already a thread processing the queue, hence "unsafe"
|
||||||
@ -294,7 +294,7 @@ class ReplicationCommandHandler:
|
|||||||
# This shouldn't be possible
|
# This shouldn't be possible
|
||||||
raise Exception("Unrecognised command %s in stream queue", cmd.NAME)
|
raise Exception("Unrecognised command %s in stream queue", cmd.NAME)
|
||||||
|
|
||||||
def start_replication(self, hs: "HomeServer"):
|
def start_replication(self, hs: "HomeServer") -> None:
|
||||||
"""Helper method to start a replication connection to the remote server
|
"""Helper method to start a replication connection to the remote server
|
||||||
using TCP.
|
using TCP.
|
||||||
"""
|
"""
|
||||||
@ -345,10 +345,10 @@ class ReplicationCommandHandler:
|
|||||||
"""Get a list of streams that this instances replicates."""
|
"""Get a list of streams that this instances replicates."""
|
||||||
return self._streams_to_replicate
|
return self._streams_to_replicate
|
||||||
|
|
||||||
def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
|
def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand) -> None:
|
||||||
self.send_positions_to_connection(conn)
|
self.send_positions_to_connection(conn)
|
||||||
|
|
||||||
def send_positions_to_connection(self, conn: IReplicationConnection):
|
def send_positions_to_connection(self, conn: IReplicationConnection) -> None:
|
||||||
"""Send current position of all streams this process is source of to
|
"""Send current position of all streams this process is source of to
|
||||||
the connection.
|
the connection.
|
||||||
"""
|
"""
|
||||||
@ -392,7 +392,7 @@ class ReplicationCommandHandler:
|
|||||||
|
|
||||||
def on_FEDERATION_ACK(
|
def on_FEDERATION_ACK(
|
||||||
self, conn: IReplicationConnection, cmd: FederationAckCommand
|
self, conn: IReplicationConnection, cmd: FederationAckCommand
|
||||||
):
|
) -> None:
|
||||||
federation_ack_counter.inc()
|
federation_ack_counter.inc()
|
||||||
|
|
||||||
if self._federation_sender:
|
if self._federation_sender:
|
||||||
@ -408,7 +408,7 @@ class ReplicationCommandHandler:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _handle_user_ip(self, cmd: UserIpCommand):
|
async def _handle_user_ip(self, cmd: UserIpCommand) -> None:
|
||||||
await self._store.insert_client_ip(
|
await self._store.insert_client_ip(
|
||||||
cmd.user_id,
|
cmd.user_id,
|
||||||
cmd.access_token,
|
cmd.access_token,
|
||||||
@ -421,7 +421,7 @@ class ReplicationCommandHandler:
|
|||||||
assert self._server_notices_sender is not None
|
assert self._server_notices_sender is not None
|
||||||
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
||||||
|
|
||||||
def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
|
def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand) -> None:
|
||||||
if cmd.instance_name == self._instance_name:
|
if cmd.instance_name == self._instance_name:
|
||||||
# Ignore RDATA that are just our own echoes
|
# Ignore RDATA that are just our own echoes
|
||||||
return
|
return
|
||||||
@ -497,7 +497,7 @@ class ReplicationCommandHandler:
|
|||||||
|
|
||||||
async def on_rdata(
|
async def on_rdata(
|
||||||
self, stream_name: str, instance_name: str, token: int, rows: list
|
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||||
):
|
) -> None:
|
||||||
"""Called to handle a batch of replication data with a given stream token.
|
"""Called to handle a batch of replication data with a given stream token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -512,7 +512,7 @@ class ReplicationCommandHandler:
|
|||||||
stream_name, instance_name, token, rows
|
stream_name, instance_name, token, rows
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
|
def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand) -> None:
|
||||||
if cmd.instance_name == self._instance_name:
|
if cmd.instance_name == self._instance_name:
|
||||||
# Ignore POSITION that are just our own echoes
|
# Ignore POSITION that are just our own echoes
|
||||||
return
|
return
|
||||||
@ -581,7 +581,7 @@ class ReplicationCommandHandler:
|
|||||||
|
|
||||||
def on_REMOTE_SERVER_UP(
|
def on_REMOTE_SERVER_UP(
|
||||||
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
|
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
|
||||||
):
|
) -> None:
|
||||||
"""Called when get a new REMOTE_SERVER_UP command."""
|
"""Called when get a new REMOTE_SERVER_UP command."""
|
||||||
self._replication_data_handler.on_remote_server_up(cmd.data)
|
self._replication_data_handler.on_remote_server_up(cmd.data)
|
||||||
|
|
||||||
@ -604,7 +604,7 @@ class ReplicationCommandHandler:
|
|||||||
# between two instances, but that is not currently supported).
|
# between two instances, but that is not currently supported).
|
||||||
self.send_command(cmd, ignore_conn=conn)
|
self.send_command(cmd, ignore_conn=conn)
|
||||||
|
|
||||||
def new_connection(self, connection: IReplicationConnection):
|
def new_connection(self, connection: IReplicationConnection) -> None:
|
||||||
"""Called when we have a new connection."""
|
"""Called when we have a new connection."""
|
||||||
self._connections.append(connection)
|
self._connections.append(connection)
|
||||||
|
|
||||||
@ -631,7 +631,7 @@ class ReplicationCommandHandler:
|
|||||||
UserSyncCommand(self._instance_id, user_id, True, now)
|
UserSyncCommand(self._instance_id, user_id, True, now)
|
||||||
)
|
)
|
||||||
|
|
||||||
def lost_connection(self, connection: IReplicationConnection):
|
def lost_connection(self, connection: IReplicationConnection) -> None:
|
||||||
"""Called when a connection is closed/lost."""
|
"""Called when a connection is closed/lost."""
|
||||||
# we no longer need _streams_by_connection for this connection.
|
# we no longer need _streams_by_connection for this connection.
|
||||||
streams = self._streams_by_connection.pop(connection, None)
|
streams = self._streams_by_connection.pop(connection, None)
|
||||||
@ -653,7 +653,7 @@ class ReplicationCommandHandler:
|
|||||||
|
|
||||||
def send_command(
|
def send_command(
|
||||||
self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
|
self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
|
||||||
):
|
) -> None:
|
||||||
"""Send a command to all connected connections.
|
"""Send a command to all connected connections.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -680,7 +680,7 @@ class ReplicationCommandHandler:
|
|||||||
else:
|
else:
|
||||||
logger.warning("Dropping command as not connected: %r", cmd.NAME)
|
logger.warning("Dropping command as not connected: %r", cmd.NAME)
|
||||||
|
|
||||||
def send_federation_ack(self, token: int):
|
def send_federation_ack(self, token: int) -> None:
|
||||||
"""Ack data for the federation stream. This allows the master to drop
|
"""Ack data for the federation stream. This allows the master to drop
|
||||||
data stored purely in memory.
|
data stored purely in memory.
|
||||||
"""
|
"""
|
||||||
@ -688,7 +688,7 @@ class ReplicationCommandHandler:
|
|||||||
|
|
||||||
def send_user_sync(
|
def send_user_sync(
|
||||||
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
|
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
|
||||||
):
|
) -> None:
|
||||||
"""Poke the master that a user has started/stopped syncing."""
|
"""Poke the master that a user has started/stopped syncing."""
|
||||||
self.send_command(
|
self.send_command(
|
||||||
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
|
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
|
||||||
@ -702,15 +702,15 @@ class ReplicationCommandHandler:
|
|||||||
user_agent: str,
|
user_agent: str,
|
||||||
device_id: str,
|
device_id: str,
|
||||||
last_seen: int,
|
last_seen: int,
|
||||||
):
|
) -> None:
|
||||||
"""Tell the master that the user made a request."""
|
"""Tell the master that the user made a request."""
|
||||||
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
|
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
|
||||||
self.send_command(cmd)
|
self.send_command(cmd)
|
||||||
|
|
||||||
def send_remote_server_up(self, server: str):
|
def send_remote_server_up(self, server: str) -> None:
|
||||||
self.send_command(RemoteServerUpCommand(server))
|
self.send_command(RemoteServerUpCommand(server))
|
||||||
|
|
||||||
def stream_update(self, stream_name: str, token: str, data: Any):
|
def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None:
|
||||||
"""Called when a new update is available to stream to clients.
|
"""Called when a new update is available to stream to clients.
|
||||||
|
|
||||||
We need to check if the client is interested in the stream or not
|
We need to check if the client is interested in the stream or not
|
||||||
|
@ -49,7 +49,7 @@ import fcntl
|
|||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from inspect import isawaitable
|
from inspect import isawaitable
|
||||||
from typing import TYPE_CHECKING, Collection, List, Optional
|
from typing import TYPE_CHECKING, Any, Collection, List, Optional
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
from zope.interface import Interface, implementer
|
from zope.interface import Interface, implementer
|
||||||
@ -123,7 +123,7 @@ class ConnectionStates:
|
|||||||
class IReplicationConnection(Interface):
|
class IReplicationConnection(Interface):
|
||||||
"""An interface for replication connections."""
|
"""An interface for replication connections."""
|
||||||
|
|
||||||
def send_command(cmd: Command):
|
def send_command(cmd: Command) -> None:
|
||||||
"""Send the command down the connection"""
|
"""Send the command down the connection"""
|
||||||
|
|
||||||
|
|
||||||
@ -190,7 +190,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
"replication-conn", self.conn_id
|
"replication-conn", self.conn_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def connectionMade(self):
|
def connectionMade(self) -> None:
|
||||||
logger.info("[%s] Connection established", self.id())
|
logger.info("[%s] Connection established", self.id())
|
||||||
|
|
||||||
self.state = ConnectionStates.ESTABLISHED
|
self.state = ConnectionStates.ESTABLISHED
|
||||||
@ -207,11 +207,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
|
|
||||||
# Always send the initial PING so that the other side knows that they
|
# Always send the initial PING so that the other side knows that they
|
||||||
# can time us out.
|
# can time us out.
|
||||||
self.send_command(PingCommand(self.clock.time_msec()))
|
self.send_command(PingCommand(str(self.clock.time_msec())))
|
||||||
|
|
||||||
self.command_handler.new_connection(self)
|
self.command_handler.new_connection(self)
|
||||||
|
|
||||||
def send_ping(self):
|
def send_ping(self) -> None:
|
||||||
"""Periodically sends a ping and checks if we should close the connection
|
"""Periodically sends a ping and checks if we should close the connection
|
||||||
due to the other side timing out.
|
due to the other side timing out.
|
||||||
"""
|
"""
|
||||||
@ -226,7 +226,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
self.transport.abortConnection()
|
self.transport.abortConnection()
|
||||||
else:
|
else:
|
||||||
if now - self.last_sent_command >= PING_TIME:
|
if now - self.last_sent_command >= PING_TIME:
|
||||||
self.send_command(PingCommand(now))
|
self.send_command(PingCommand(str(now)))
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.received_ping
|
self.received_ping
|
||||||
@ -239,12 +239,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
)
|
)
|
||||||
self.send_error("ping timeout")
|
self.send_error("ping timeout")
|
||||||
|
|
||||||
def lineReceived(self, line: bytes):
|
def lineReceived(self, line: bytes) -> None:
|
||||||
"""Called when we've received a line"""
|
"""Called when we've received a line"""
|
||||||
with PreserveLoggingContext(self._logging_context):
|
with PreserveLoggingContext(self._logging_context):
|
||||||
self._parse_and_dispatch_line(line)
|
self._parse_and_dispatch_line(line)
|
||||||
|
|
||||||
def _parse_and_dispatch_line(self, line: bytes):
|
def _parse_and_dispatch_line(self, line: bytes) -> None:
|
||||||
if line.strip() == "":
|
if line.strip() == "":
|
||||||
# Ignore blank lines
|
# Ignore blank lines
|
||||||
return
|
return
|
||||||
@ -309,24 +309,24 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
if not handled:
|
if not handled:
|
||||||
logger.warning("Unhandled command: %r", cmd)
|
logger.warning("Unhandled command: %r", cmd)
|
||||||
|
|
||||||
def close(self):
|
def close(self) -> None:
|
||||||
logger.warning("[%s] Closing connection", self.id())
|
logger.warning("[%s] Closing connection", self.id())
|
||||||
self.time_we_closed = self.clock.time_msec()
|
self.time_we_closed = self.clock.time_msec()
|
||||||
assert self.transport is not None
|
assert self.transport is not None
|
||||||
self.transport.loseConnection()
|
self.transport.loseConnection()
|
||||||
self.on_connection_closed()
|
self.on_connection_closed()
|
||||||
|
|
||||||
def send_error(self, error_string, *args):
|
def send_error(self, error_string: str, *args: Any) -> None:
|
||||||
"""Send an error to remote and close the connection."""
|
"""Send an error to remote and close the connection."""
|
||||||
self.send_command(ErrorCommand(error_string % args))
|
self.send_command(ErrorCommand(error_string % args))
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def send_command(self, cmd, do_buffer=True):
|
def send_command(self, cmd: Command, do_buffer: bool = True) -> None:
|
||||||
"""Send a command if connection has been established.
|
"""Send a command if connection has been established.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cmd (Command)
|
cmd
|
||||||
do_buffer (bool): Whether to buffer the message or always attempt
|
do_buffer: Whether to buffer the message or always attempt
|
||||||
to send the command. This is mostly used to send an error
|
to send the command. This is mostly used to send an error
|
||||||
message if we're about to close the connection due our buffers
|
message if we're about to close the connection due our buffers
|
||||||
becoming full.
|
becoming full.
|
||||||
@ -357,7 +357,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
|
|
||||||
self.last_sent_command = self.clock.time_msec()
|
self.last_sent_command = self.clock.time_msec()
|
||||||
|
|
||||||
def _queue_command(self, cmd):
|
def _queue_command(self, cmd: Command) -> None:
|
||||||
"""Queue the command until the connection is ready to write to again."""
|
"""Queue the command until the connection is ready to write to again."""
|
||||||
logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
|
logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
|
||||||
self.pending_commands.append(cmd)
|
self.pending_commands.append(cmd)
|
||||||
@ -370,20 +370,20 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
self.send_command(ErrorCommand("Failed to keep up"), do_buffer=False)
|
self.send_command(ErrorCommand("Failed to keep up"), do_buffer=False)
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def _send_pending_commands(self):
|
def _send_pending_commands(self) -> None:
|
||||||
"""Send any queued commandes"""
|
"""Send any queued commandes"""
|
||||||
pending = self.pending_commands
|
pending = self.pending_commands
|
||||||
self.pending_commands = []
|
self.pending_commands = []
|
||||||
for cmd in pending:
|
for cmd in pending:
|
||||||
self.send_command(cmd)
|
self.send_command(cmd)
|
||||||
|
|
||||||
def on_PING(self, line):
|
def on_PING(self, cmd: PingCommand) -> None:
|
||||||
self.received_ping = True
|
self.received_ping = True
|
||||||
|
|
||||||
def on_ERROR(self, cmd):
|
def on_ERROR(self, cmd: ErrorCommand) -> None:
|
||||||
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
|
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
|
||||||
|
|
||||||
def pauseProducing(self):
|
def pauseProducing(self) -> None:
|
||||||
"""This is called when both the kernel send buffer and the twisted
|
"""This is called when both the kernel send buffer and the twisted
|
||||||
tcp connection send buffers have become full.
|
tcp connection send buffers have become full.
|
||||||
|
|
||||||
@ -394,26 +394,26 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
logger.info("[%s] Pause producing", self.id())
|
logger.info("[%s] Pause producing", self.id())
|
||||||
self.state = ConnectionStates.PAUSED
|
self.state = ConnectionStates.PAUSED
|
||||||
|
|
||||||
def resumeProducing(self):
|
def resumeProducing(self) -> None:
|
||||||
"""The remote has caught up after we started buffering!"""
|
"""The remote has caught up after we started buffering!"""
|
||||||
logger.info("[%s] Resume producing", self.id())
|
logger.info("[%s] Resume producing", self.id())
|
||||||
self.state = ConnectionStates.ESTABLISHED
|
self.state = ConnectionStates.ESTABLISHED
|
||||||
self._send_pending_commands()
|
self._send_pending_commands()
|
||||||
|
|
||||||
def stopProducing(self):
|
def stopProducing(self) -> None:
|
||||||
"""We're never going to send any more data (normally because either
|
"""We're never going to send any more data (normally because either
|
||||||
we or the remote has closed the connection)
|
we or the remote has closed the connection)
|
||||||
"""
|
"""
|
||||||
logger.info("[%s] Stop producing", self.id())
|
logger.info("[%s] Stop producing", self.id())
|
||||||
self.on_connection_closed()
|
self.on_connection_closed()
|
||||||
|
|
||||||
def connectionLost(self, reason):
|
def connectionLost(self, reason: Failure) -> None: # type: ignore[override]
|
||||||
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
|
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
|
||||||
if isinstance(reason, Failure):
|
if isinstance(reason, Failure):
|
||||||
assert reason.type is not None
|
assert reason.type is not None
|
||||||
connection_close_counter.labels(reason.type.__name__).inc()
|
connection_close_counter.labels(reason.type.__name__).inc()
|
||||||
else:
|
else:
|
||||||
connection_close_counter.labels(reason.__class__.__name__).inc()
|
connection_close_counter.labels(reason.__class__.__name__).inc() # type: ignore[unreachable]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Remove us from list of connections to be monitored
|
# Remove us from list of connections to be monitored
|
||||||
@ -427,7 +427,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
|
|
||||||
self.on_connection_closed()
|
self.on_connection_closed()
|
||||||
|
|
||||||
def on_connection_closed(self):
|
def on_connection_closed(self) -> None:
|
||||||
logger.info("[%s] Connection was closed", self.id())
|
logger.info("[%s] Connection was closed", self.id())
|
||||||
|
|
||||||
self.state = ConnectionStates.CLOSED
|
self.state = ConnectionStates.CLOSED
|
||||||
@ -445,7 +445,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
# the sentinel context is now active, which may not be correct.
|
# the sentinel context is now active, which may not be correct.
|
||||||
# PreserveLoggingContext() will restore the correct logging context.
|
# PreserveLoggingContext() will restore the correct logging context.
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
addr = None
|
addr = None
|
||||||
if self.transport:
|
if self.transport:
|
||||||
addr = str(self.transport.getPeer())
|
addr = str(self.transport.getPeer())
|
||||||
@ -455,10 +455,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|||||||
addr,
|
addr,
|
||||||
)
|
)
|
||||||
|
|
||||||
def id(self):
|
def id(self) -> str:
|
||||||
return "%s-%s" % (self.name, self.conn_id)
|
return "%s-%s" % (self.name, self.conn_id)
|
||||||
|
|
||||||
def lineLengthExceeded(self, line):
|
def lineLengthExceeded(self, line: str) -> None:
|
||||||
"""Called when we receive a line that is above the maximum line length"""
|
"""Called when we receive a line that is above the maximum line length"""
|
||||||
self.send_error("Line length exceeded")
|
self.send_error("Line length exceeded")
|
||||||
|
|
||||||
@ -474,11 +474,11 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||||||
|
|
||||||
self.server_name = server_name
|
self.server_name = server_name
|
||||||
|
|
||||||
def connectionMade(self):
|
def connectionMade(self) -> None:
|
||||||
self.send_command(ServerCommand(self.server_name))
|
self.send_command(ServerCommand(self.server_name))
|
||||||
super().connectionMade()
|
super().connectionMade()
|
||||||
|
|
||||||
def on_NAME(self, cmd):
|
def on_NAME(self, cmd: NameCommand) -> None:
|
||||||
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
|
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
|
||||||
self.name = cmd.data
|
self.name = cmd.data
|
||||||
|
|
||||||
@ -500,19 +500,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|||||||
self.client_name = client_name
|
self.client_name = client_name
|
||||||
self.server_name = server_name
|
self.server_name = server_name
|
||||||
|
|
||||||
def connectionMade(self):
|
def connectionMade(self) -> None:
|
||||||
self.send_command(NameCommand(self.client_name))
|
self.send_command(NameCommand(self.client_name))
|
||||||
super().connectionMade()
|
super().connectionMade()
|
||||||
|
|
||||||
# Once we've connected subscribe to the necessary streams
|
# Once we've connected subscribe to the necessary streams
|
||||||
self.replicate()
|
self.replicate()
|
||||||
|
|
||||||
def on_SERVER(self, cmd):
|
def on_SERVER(self, cmd: ServerCommand) -> None:
|
||||||
if cmd.data != self.server_name:
|
if cmd.data != self.server_name:
|
||||||
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
|
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
|
||||||
self.send_error("Wrong remote")
|
self.send_error("Wrong remote")
|
||||||
|
|
||||||
def replicate(self):
|
def replicate(self) -> None:
|
||||||
"""Send the subscription request to the server"""
|
"""Send the subscription request to the server"""
|
||||||
logger.info("[%s] Subscribing to replication streams", self.id())
|
logger.info("[%s] Subscribing to replication streams", self.id())
|
||||||
|
|
||||||
@ -529,7 +529,7 @@ pending_commands = LaterGauge(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def transport_buffer_size(protocol):
|
def transport_buffer_size(protocol: BaseReplicationStreamProtocol) -> int:
|
||||||
if protocol.transport:
|
if protocol.transport:
|
||||||
size = len(protocol.transport.dataBuffer) + protocol.transport._tempDataLen
|
size = len(protocol.transport.dataBuffer) + protocol.transport._tempDataLen
|
||||||
return size
|
return size
|
||||||
@ -544,7 +544,9 @@ transport_send_buffer = LaterGauge(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def transport_kernel_read_buffer_size(protocol, read=True):
|
def transport_kernel_read_buffer_size(
|
||||||
|
protocol: BaseReplicationStreamProtocol, read: bool = True
|
||||||
|
) -> int:
|
||||||
SIOCINQ = 0x541B
|
SIOCINQ = 0x541B
|
||||||
SIOCOUTQ = 0x5411
|
SIOCOUTQ = 0x5411
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from inspect import isawaitable
|
from inspect import isawaitable
|
||||||
from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
|
from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import txredisapi
|
import txredisapi
|
||||||
@ -62,7 +62,7 @@ class ConstantProperty(Generic[T, V]):
|
|||||||
def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
|
def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
|
||||||
return self.constant
|
return self.constant
|
||||||
|
|
||||||
def __set__(self, obj: Optional[T], value: V):
|
def __set__(self, obj: Optional[T], value: V) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -95,7 +95,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
|
|||||||
synapse_stream_name: str
|
synapse_stream_name: str
|
||||||
synapse_outbound_redis_connection: txredisapi.RedisProtocol
|
synapse_outbound_redis_connection: txredisapi.RedisProtocol
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args: Any, **kwargs: Any):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
# a logcontext which we use for processing incoming commands. We declare it as a
|
# a logcontext which we use for processing incoming commands. We declare it as a
|
||||||
@ -108,12 +108,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
|
|||||||
"replication_command_handler"
|
"replication_command_handler"
|
||||||
)
|
)
|
||||||
|
|
||||||
def connectionMade(self):
|
def connectionMade(self) -> None:
|
||||||
logger.info("Connected to redis")
|
logger.info("Connected to redis")
|
||||||
super().connectionMade()
|
super().connectionMade()
|
||||||
run_as_background_process("subscribe-replication", self._send_subscribe)
|
run_as_background_process("subscribe-replication", self._send_subscribe)
|
||||||
|
|
||||||
async def _send_subscribe(self):
|
async def _send_subscribe(self) -> None:
|
||||||
# it's important to make sure that we only send the REPLICATE command once we
|
# it's important to make sure that we only send the REPLICATE command once we
|
||||||
# have successfully subscribed to the stream - otherwise we might miss the
|
# have successfully subscribed to the stream - otherwise we might miss the
|
||||||
# POSITION response sent back by the other end.
|
# POSITION response sent back by the other end.
|
||||||
@ -131,12 +131,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
|
|||||||
# otherside won't know we've connected and so won't issue a REPLICATE.
|
# otherside won't know we've connected and so won't issue a REPLICATE.
|
||||||
self.synapse_handler.send_positions_to_connection(self)
|
self.synapse_handler.send_positions_to_connection(self)
|
||||||
|
|
||||||
def messageReceived(self, pattern: str, channel: str, message: str):
|
def messageReceived(self, pattern: str, channel: str, message: str) -> None:
|
||||||
"""Received a message from redis."""
|
"""Received a message from redis."""
|
||||||
with PreserveLoggingContext(self._logging_context):
|
with PreserveLoggingContext(self._logging_context):
|
||||||
self._parse_and_dispatch_message(message)
|
self._parse_and_dispatch_message(message)
|
||||||
|
|
||||||
def _parse_and_dispatch_message(self, message: str):
|
def _parse_and_dispatch_message(self, message: str) -> None:
|
||||||
if message.strip() == "":
|
if message.strip() == "":
|
||||||
# Ignore blank lines
|
# Ignore blank lines
|
||||||
return
|
return
|
||||||
@ -181,7 +181,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
|
|||||||
"replication-" + cmd.get_logcontext_id(), lambda: res
|
"replication-" + cmd.get_logcontext_id(), lambda: res
|
||||||
)
|
)
|
||||||
|
|
||||||
def connectionLost(self, reason):
|
def connectionLost(self, reason: Failure) -> None: # type: ignore[override]
|
||||||
logger.info("Lost connection to redis")
|
logger.info("Lost connection to redis")
|
||||||
super().connectionLost(reason)
|
super().connectionLost(reason)
|
||||||
self.synapse_handler.lost_connection(self)
|
self.synapse_handler.lost_connection(self)
|
||||||
@ -193,17 +193,17 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
|
|||||||
# the sentinel context is now active, which may not be correct.
|
# the sentinel context is now active, which may not be correct.
|
||||||
# PreserveLoggingContext() will restore the correct logging context.
|
# PreserveLoggingContext() will restore the correct logging context.
|
||||||
|
|
||||||
def send_command(self, cmd: Command):
|
def send_command(self, cmd: Command) -> None:
|
||||||
"""Send a command if connection has been established.
|
"""Send a command if connection has been established.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cmd (Command)
|
cmd: The command to send
|
||||||
"""
|
"""
|
||||||
run_as_background_process(
|
run_as_background_process(
|
||||||
"send-cmd", self._async_send_command, cmd, bg_start_span=False
|
"send-cmd", self._async_send_command, cmd, bg_start_span=False
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _async_send_command(self, cmd: Command):
|
async def _async_send_command(self, cmd: Command) -> None:
|
||||||
"""Encode a replication command and send it over our outbound connection"""
|
"""Encode a replication command and send it over our outbound connection"""
|
||||||
string = "%s %s" % (cmd.NAME, cmd.to_line())
|
string = "%s %s" % (cmd.NAME, cmd.to_line())
|
||||||
if "\n" in string:
|
if "\n" in string:
|
||||||
@ -259,7 +259,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
|
|||||||
hs.get_clock().looping_call(self._send_ping, 30 * 1000)
|
hs.get_clock().looping_call(self._send_ping, 30 * 1000)
|
||||||
|
|
||||||
@wrap_as_background_process("redis_ping")
|
@wrap_as_background_process("redis_ping")
|
||||||
async def _send_ping(self):
|
async def _send_ping(self) -> None:
|
||||||
for connection in self.pool:
|
for connection in self.pool:
|
||||||
try:
|
try:
|
||||||
await make_deferred_yieldable(connection.ping())
|
await make_deferred_yieldable(connection.ping())
|
||||||
@ -269,13 +269,13 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
|
|||||||
# ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
|
# ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
|
||||||
# it's rubbish. We add our own here.
|
# it's rubbish. We add our own here.
|
||||||
|
|
||||||
def startedConnecting(self, connector: IConnector):
|
def startedConnecting(self, connector: IConnector) -> None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Connecting to redis server %s", format_address(connector.getDestination())
|
"Connecting to redis server %s", format_address(connector.getDestination())
|
||||||
)
|
)
|
||||||
super().startedConnecting(connector)
|
super().startedConnecting(connector)
|
||||||
|
|
||||||
def clientConnectionFailed(self, connector: IConnector, reason: Failure):
|
def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Connection to redis server %s failed: %s",
|
"Connection to redis server %s failed: %s",
|
||||||
format_address(connector.getDestination()),
|
format_address(connector.getDestination()),
|
||||||
@ -283,7 +283,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
|
|||||||
)
|
)
|
||||||
super().clientConnectionFailed(connector, reason)
|
super().clientConnectionFailed(connector, reason)
|
||||||
|
|
||||||
def clientConnectionLost(self, connector: IConnector, reason: Failure):
|
def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Connection to redis server %s lost: %s",
|
"Connection to redis server %s lost: %s",
|
||||||
format_address(connector.getDestination()),
|
format_address(connector.getDestination()),
|
||||||
@ -330,7 +330,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
|
|||||||
|
|
||||||
self.synapse_outbound_redis_connection = outbound_redis_connection
|
self.synapse_outbound_redis_connection = outbound_redis_connection
|
||||||
|
|
||||||
def buildProtocol(self, addr):
|
def buildProtocol(self, addr: IAddress) -> RedisSubscriber:
|
||||||
p = super().buildProtocol(addr)
|
p = super().buildProtocol(addr)
|
||||||
p = cast(RedisSubscriber, p)
|
p = cast(RedisSubscriber, p)
|
||||||
|
|
||||||
|
@ -16,16 +16,18 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
|
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
|
||||||
|
from twisted.internet.interfaces import IAddress
|
||||||
from twisted.internet.protocol import ServerFactory
|
from twisted.internet.protocol import ServerFactory
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.replication.tcp.commands import PositionCommand
|
from synapse.replication.tcp.commands import PositionCommand
|
||||||
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
|
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
|
||||||
from synapse.replication.tcp.streams import EventsStream
|
from synapse.replication.tcp.streams import EventsStream
|
||||||
|
from synapse.replication.tcp.streams._base import StreamRow, Token
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -56,7 +58,7 @@ class ReplicationStreamProtocolFactory(ServerFactory):
|
|||||||
# listener config again or always starting a `ReplicationStreamer`.)
|
# listener config again or always starting a `ReplicationStreamer`.)
|
||||||
hs.get_replication_streamer()
|
hs.get_replication_streamer()
|
||||||
|
|
||||||
def buildProtocol(self, addr):
|
def buildProtocol(self, addr: IAddress) -> ServerReplicationStreamProtocol:
|
||||||
return ServerReplicationStreamProtocol(
|
return ServerReplicationStreamProtocol(
|
||||||
self.server_name, self.clock, self.command_handler
|
self.server_name, self.clock, self.command_handler
|
||||||
)
|
)
|
||||||
@ -105,7 +107,7 @@ class ReplicationStreamer:
|
|||||||
if any(EventsStream.NAME == s.NAME for s in self.streams):
|
if any(EventsStream.NAME == s.NAME for s in self.streams):
|
||||||
self.clock.looping_call(self.on_notifier_poke, 1000)
|
self.clock.looping_call(self.on_notifier_poke, 1000)
|
||||||
|
|
||||||
def on_notifier_poke(self):
|
def on_notifier_poke(self) -> None:
|
||||||
"""Checks if there is actually any new data and sends it to the
|
"""Checks if there is actually any new data and sends it to the
|
||||||
connections if there are.
|
connections if there are.
|
||||||
|
|
||||||
@ -137,7 +139,7 @@ class ReplicationStreamer:
|
|||||||
|
|
||||||
run_as_background_process("replication_notifier", self._run_notifier_loop)
|
run_as_background_process("replication_notifier", self._run_notifier_loop)
|
||||||
|
|
||||||
async def _run_notifier_loop(self):
|
async def _run_notifier_loop(self) -> None:
|
||||||
self.is_looping = True
|
self.is_looping = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -238,7 +240,9 @@ class ReplicationStreamer:
|
|||||||
self.is_looping = False
|
self.is_looping = False
|
||||||
|
|
||||||
|
|
||||||
def _batch_updates(updates):
|
def _batch_updates(
|
||||||
|
updates: List[Tuple[Token, StreamRow]]
|
||||||
|
) -> List[Tuple[Optional[Token], StreamRow]]:
|
||||||
"""Takes a list of updates of form [(token, row)] and sets the token to
|
"""Takes a list of updates of form [(token, row)] and sets the token to
|
||||||
None for all rows where the next row has the same token. This is used to
|
None for all rows where the next row has the same token. This is used to
|
||||||
implement batching.
|
implement batching.
|
||||||
@ -254,7 +258,7 @@ def _batch_updates(updates):
|
|||||||
if not updates:
|
if not updates:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
new_updates = []
|
new_updates: List[Tuple[Optional[Token], StreamRow]] = []
|
||||||
for i, update in enumerate(updates[:-1]):
|
for i, update in enumerate(updates[:-1]):
|
||||||
if update[0] == updates[i + 1][0]:
|
if update[0] == updates[i + 1][0]:
|
||||||
new_updates.append((None, update[1]))
|
new_updates.append((None, update[1]))
|
||||||
|
@ -90,7 +90,7 @@ class Stream:
|
|||||||
ROW_TYPE: Any = None
|
ROW_TYPE: Any = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_row(cls, row: StreamRow):
|
def parse_row(cls, row: StreamRow) -> Any:
|
||||||
"""Parse a row received over replication
|
"""Parse a row received over replication
|
||||||
|
|
||||||
By default, assumes that the row data is an array object and passes its contents
|
By default, assumes that the row data is an array object and passes its contents
|
||||||
@ -139,7 +139,7 @@ class Stream:
|
|||||||
# The token from which we last asked for updates
|
# The token from which we last asked for updates
|
||||||
self.last_token = self.current_token(self.local_instance_name)
|
self.last_token = self.current_token(self.local_instance_name)
|
||||||
|
|
||||||
def discard_updates_and_advance(self):
|
def discard_updates_and_advance(self) -> None:
|
||||||
"""Called when the stream should advance but the updates would be discarded,
|
"""Called when the stream should advance but the updates would be discarded,
|
||||||
e.g. when there are no currently connected workers.
|
e.g. when there are no currently connected workers.
|
||||||
"""
|
"""
|
||||||
@ -200,7 +200,7 @@ def current_token_without_instance(
|
|||||||
return lambda instance_name: current_token()
|
return lambda instance_name: current_token()
|
||||||
|
|
||||||
|
|
||||||
def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
|
def make_http_update_function(hs: "HomeServer", stream_name: str) -> UpdateFunction:
|
||||||
"""Makes a suitable function for use as an `update_function` that queries
|
"""Makes a suitable function for use as an `update_function` that queries
|
||||||
the master process for updates.
|
the master process for updates.
|
||||||
"""
|
"""
|
||||||
|
@ -13,12 +13,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import heapq
|
import heapq
|
||||||
from collections.abc import Iterable
|
from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple, Type
|
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from ._base import Stream, StreamUpdateResult, Token
|
from synapse.replication.tcp.streams._base import (
|
||||||
|
Stream,
|
||||||
|
StreamRow,
|
||||||
|
StreamUpdateResult,
|
||||||
|
Token,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
@ -58,6 +62,9 @@ class EventsStreamRow:
|
|||||||
data: "BaseEventsStreamRow"
|
data: "BaseEventsStreamRow"
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="BaseEventsStreamRow")
|
||||||
|
|
||||||
|
|
||||||
class BaseEventsStreamRow:
|
class BaseEventsStreamRow:
|
||||||
"""Base class for rows to be sent in the events stream.
|
"""Base class for rows to be sent in the events stream.
|
||||||
|
|
||||||
@ -68,7 +75,7 @@ class BaseEventsStreamRow:
|
|||||||
TypeId: str
|
TypeId: str
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_data(cls, data):
|
def from_data(cls: Type[T], data: Iterable[Optional[str]]) -> T:
|
||||||
"""Parse the data from the replication stream into a row.
|
"""Parse the data from the replication stream into a row.
|
||||||
|
|
||||||
By default we just call the constructor with the data list as arguments
|
By default we just call the constructor with the data list as arguments
|
||||||
@ -221,7 +228,7 @@ class EventsStream(Stream):
|
|||||||
return updates, upper_limit, limited
|
return updates, upper_limit, limited
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_row(cls, row):
|
def parse_row(cls, row: StreamRow) -> "EventsStreamRow":
|
||||||
(typ, data) = row
|
(typ, data) = cast(Tuple[str, Iterable[Optional[str]]], row)
|
||||||
data = TypeToRow[typ].from_data(data)
|
event_stream_row_data = TypeToRow[typ].from_data(data)
|
||||||
return EventsStreamRow(typ, data)
|
return EventsStreamRow(typ, event_stream_row_data)
|
||||||
|
@ -16,8 +16,7 @@ import itertools
|
|||||||
import re
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
from collections.abc import Iterable
|
from typing import Iterable, Optional, Tuple
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
from netaddr import valid_ipv6
|
from netaddr import valid_ipv6
|
||||||
|
|
||||||
@ -197,7 +196,7 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
|
|||||||
"""If iterable has maxitems or fewer, return the stringification of a list
|
"""If iterable has maxitems or fewer, return the stringification of a list
|
||||||
containing those items.
|
containing those items.
|
||||||
|
|
||||||
Otherwise, return the stringification of a a list with the first maxitems items,
|
Otherwise, return the stringification of a list with the first maxitems items,
|
||||||
followed by "...".
|
followed by "...".
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from twisted.internet.address import IPv4Address
|
||||||
from twisted.internet.protocol import Protocol
|
from twisted.internet.protocol import Protocol
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
@ -53,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||||||
server_factory = ReplicationStreamProtocolFactory(hs)
|
server_factory = ReplicationStreamProtocolFactory(hs)
|
||||||
self.streamer = hs.get_replication_streamer()
|
self.streamer = hs.get_replication_streamer()
|
||||||
self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
|
self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
|
||||||
None
|
IPv4Address("TCP", "127.0.0.1", 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make a new HomeServer object for the worker
|
# Make a new HomeServer object for the worker
|
||||||
@ -345,7 +346,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
|||||||
self.clock,
|
self.clock,
|
||||||
repl_handler,
|
repl_handler,
|
||||||
)
|
)
|
||||||
server = self.server_factory.buildProtocol(None)
|
server = self.server_factory.buildProtocol(
|
||||||
|
IPv4Address("TCP", "127.0.0.1", 0)
|
||||||
|
)
|
||||||
|
|
||||||
client_transport = FakeTransport(server, self.reactor)
|
client_transport = FakeTransport(server, self.reactor)
|
||||||
client.makeConnection(client_transport)
|
client.makeConnection(client_transport)
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
from twisted.internet.address import IPv4Address
|
||||||
from twisted.internet.interfaces import IProtocol
|
from twisted.internet.interfaces import IProtocol
|
||||||
from twisted.test.proto_helpers import StringTransport
|
from twisted.test.proto_helpers import StringTransport
|
||||||
|
|
||||||
@ -29,7 +30,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
|
|||||||
def _make_client(self) -> Tuple[IProtocol, StringTransport]:
|
def _make_client(self) -> Tuple[IProtocol, StringTransport]:
|
||||||
"""Create a new direct TCP replication connection"""
|
"""Create a new direct TCP replication connection"""
|
||||||
|
|
||||||
proto = self.factory.buildProtocol(("127.0.0.1", 0))
|
proto = self.factory.buildProtocol(IPv4Address("TCP", "127.0.0.1", 0))
|
||||||
transport = StringTransport()
|
transport = StringTransport()
|
||||||
proto.makeConnection(transport)
|
proto.makeConnection(transport)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user