mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Fix bug where sync could get stuck when using workers (#17438)
This is because we serialized the token wrong if the instance map contained entries from before the minimum token.
This commit is contained in:
parent
d88ba45db9
commit
df11af14db
1
changelog.d/17438.bugfix
Normal file
1
changelog.d/17438.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fix rare bug where `/sync` would break for a user when using workers with multiple stream writers.
|
@ -699,10 +699,17 @@ class SlidingSyncHandler:
|
|||||||
instance_to_max_stream_ordering_map[instance_name] = stream_ordering
|
instance_to_max_stream_ordering_map[instance_name] = stream_ordering
|
||||||
|
|
||||||
# Then assemble the `RoomStreamToken`
|
# Then assemble the `RoomStreamToken`
|
||||||
|
min_stream_pos = min(instance_to_max_stream_ordering_map.values())
|
||||||
membership_snapshot_token = RoomStreamToken(
|
membership_snapshot_token = RoomStreamToken(
|
||||||
# Minimum position in the `instance_map`
|
# Minimum position in the `instance_map`
|
||||||
stream=min(instance_to_max_stream_ordering_map.values()),
|
stream=min_stream_pos,
|
||||||
instance_map=immutabledict(instance_to_max_stream_ordering_map),
|
instance_map=immutabledict(
|
||||||
|
{
|
||||||
|
instance_name: stream_pos
|
||||||
|
for instance_name, stream_pos in instance_to_max_stream_ordering_map.items()
|
||||||
|
if stream_pos > min_stream_pos
|
||||||
|
}
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Since we fetched the users room list at some point in time after the from/to
|
# Since we fetched the users room list at some point in time after the from/to
|
||||||
|
@ -20,6 +20,7 @@
|
|||||||
#
|
#
|
||||||
#
|
#
|
||||||
import abc
|
import abc
|
||||||
|
import logging
|
||||||
import re
|
import re
|
||||||
import string
|
import string
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -74,6 +75,9 @@ if TYPE_CHECKING:
|
|||||||
from synapse.storage.databases.main import DataStore, PurgeEventsStore
|
from synapse.storage.databases.main import DataStore, PurgeEventsStore
|
||||||
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
|
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Define a state map type from type/state_key to T (usually an event ID or
|
# Define a state map type from type/state_key to T (usually an event ID or
|
||||||
# event)
|
# event)
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
@ -454,6 +458,8 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
|
|||||||
represented by a default `stream` attribute and a map of instance name to
|
represented by a default `stream` attribute and a map of instance name to
|
||||||
stream position of any writers that are ahead of the default stream
|
stream position of any writers that are ahead of the default stream
|
||||||
position.
|
position.
|
||||||
|
|
||||||
|
The values in `instance_map` must be greater than the `stream` attribute.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
|
stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
|
||||||
@ -468,6 +474,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
|
|||||||
kw_only=True,
|
kw_only=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __attrs_post_init__(self) -> None:
|
||||||
|
# Enforce that all instances have a value greater than the min stream
|
||||||
|
# position.
|
||||||
|
for i, v in self.instance_map.items():
|
||||||
|
if v <= self.stream:
|
||||||
|
raise ValueError(
|
||||||
|
f"'instance_map' includes a stream position before the main 'stream' attribute. Instance: {i}"
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def parse(cls, store: "DataStore", string: str) -> "Self":
|
async def parse(cls, store: "DataStore", string: str) -> "Self":
|
||||||
@ -494,6 +509,9 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
|
|||||||
for instance in set(self.instance_map).union(other.instance_map)
|
for instance in set(self.instance_map).union(other.instance_map)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Filter out any redundant entries.
|
||||||
|
instance_map = {i: s for i, s in instance_map.items() if s > max_stream}
|
||||||
|
|
||||||
return attr.evolve(
|
return attr.evolve(
|
||||||
self, stream=max_stream, instance_map=immutabledict(instance_map)
|
self, stream=max_stream, instance_map=immutabledict(instance_map)
|
||||||
)
|
)
|
||||||
@ -539,10 +557,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
|
|||||||
def bound_stream_token(self, max_stream: int) -> "Self":
|
def bound_stream_token(self, max_stream: int) -> "Self":
|
||||||
"""Bound the stream positions to a maximum value"""
|
"""Bound the stream positions to a maximum value"""
|
||||||
|
|
||||||
|
min_pos = min(self.stream, max_stream)
|
||||||
return type(self)(
|
return type(self)(
|
||||||
stream=min(self.stream, max_stream),
|
stream=min_pos,
|
||||||
instance_map=immutabledict(
|
instance_map=immutabledict(
|
||||||
{k: min(s, max_stream) for k, s in self.instance_map.items()}
|
{
|
||||||
|
k: min(s, max_stream)
|
||||||
|
for k, s in self.instance_map.items()
|
||||||
|
if min(s, max_stream) > min_pos
|
||||||
|
}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -637,6 +660,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
|
|||||||
"Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
|
"Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
super().__attrs_post_init__()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
|
async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
|
||||||
try:
|
try:
|
||||||
@ -651,6 +676,11 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
|
|||||||
|
|
||||||
instance_map = {}
|
instance_map = {}
|
||||||
for part in parts[1:]:
|
for part in parts[1:]:
|
||||||
|
if not part:
|
||||||
|
# Handle tokens of the form `m5~`, which were created by
|
||||||
|
# a bug
|
||||||
|
continue
|
||||||
|
|
||||||
key, value = part.split(".")
|
key, value = part.split(".")
|
||||||
instance_id = int(key)
|
instance_id = int(key)
|
||||||
pos = int(value)
|
pos = int(value)
|
||||||
@ -666,7 +696,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
|
|||||||
except CancelledError:
|
except CancelledError:
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
# We log an exception here as even though this *might* be a client
|
||||||
|
# handing a bad token, its more likely that Synapse returned a bad
|
||||||
|
# token (and we really want to catch those!).
|
||||||
|
logger.exception("Failed to parse stream token: %r", string)
|
||||||
raise SynapseError(400, "Invalid room stream token %r" % (string,))
|
raise SynapseError(400, "Invalid room stream token %r" % (string,))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -713,6 +746,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
|
|||||||
return self.instance_map.get(instance_name, self.stream)
|
return self.instance_map.get(instance_name, self.stream)
|
||||||
|
|
||||||
async def to_string(self, store: "DataStore") -> str:
|
async def to_string(self, store: "DataStore") -> str:
|
||||||
|
"""See class level docstring for information about the format."""
|
||||||
|
|
||||||
if self.topological is not None:
|
if self.topological is not None:
|
||||||
return "t%d-%d" % (self.topological, self.stream)
|
return "t%d-%d" % (self.topological, self.stream)
|
||||||
elif self.instance_map:
|
elif self.instance_map:
|
||||||
@ -727,8 +762,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
|
|||||||
instance_id = await store.get_id_for_instance(name)
|
instance_id = await store.get_id_for_instance(name)
|
||||||
entries.append(f"{instance_id}.{pos}")
|
entries.append(f"{instance_id}.{pos}")
|
||||||
|
|
||||||
encoded_map = "~".join(entries)
|
if entries:
|
||||||
return f"m{self.stream}~{encoded_map}"
|
encoded_map = "~".join(entries)
|
||||||
|
return f"m{self.stream}~{encoded_map}"
|
||||||
|
return f"s{self.stream}"
|
||||||
else:
|
else:
|
||||||
return "s%d" % (self.stream,)
|
return "s%d" % (self.stream,)
|
||||||
|
|
||||||
@ -756,6 +793,11 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
|
|||||||
|
|
||||||
instance_map = {}
|
instance_map = {}
|
||||||
for part in parts[1:]:
|
for part in parts[1:]:
|
||||||
|
if not part:
|
||||||
|
# Handle tokens of the form `m5~`, which were created by
|
||||||
|
# a bug
|
||||||
|
continue
|
||||||
|
|
||||||
key, value = part.split(".")
|
key, value = part.split(".")
|
||||||
instance_id = int(key)
|
instance_id = int(key)
|
||||||
pos = int(value)
|
pos = int(value)
|
||||||
@ -770,10 +812,15 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
|
|||||||
except CancelledError:
|
except CancelledError:
|
||||||
raise
|
raise
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
# We log an exception here as even though this *might* be a client
|
||||||
|
# handing a bad token, its more likely that Synapse returned a bad
|
||||||
|
# token (and we really want to catch those!).
|
||||||
|
logger.exception("Failed to parse stream token: %r", string)
|
||||||
raise SynapseError(400, "Invalid stream token %r" % (string,))
|
raise SynapseError(400, "Invalid stream token %r" % (string,))
|
||||||
|
|
||||||
async def to_string(self, store: "DataStore") -> str:
|
async def to_string(self, store: "DataStore") -> str:
|
||||||
|
"""See class level docstring for information about the format."""
|
||||||
|
|
||||||
if self.instance_map:
|
if self.instance_map:
|
||||||
entries = []
|
entries = []
|
||||||
for name, pos in self.instance_map.items():
|
for name, pos in self.instance_map.items():
|
||||||
@ -786,8 +833,10 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
|
|||||||
instance_id = await store.get_id_for_instance(name)
|
instance_id = await store.get_id_for_instance(name)
|
||||||
entries.append(f"{instance_id}.{pos}")
|
entries.append(f"{instance_id}.{pos}")
|
||||||
|
|
||||||
encoded_map = "~".join(entries)
|
if entries:
|
||||||
return f"m{self.stream}~{encoded_map}"
|
encoded_map = "~".join(entries)
|
||||||
|
return f"m{self.stream}~{encoded_map}"
|
||||||
|
return str(self.stream)
|
||||||
else:
|
else:
|
||||||
return str(self.stream)
|
return str(self.stream)
|
||||||
|
|
||||||
|
@ -19,9 +19,18 @@
|
|||||||
#
|
#
|
||||||
#
|
#
|
||||||
|
|
||||||
|
from typing import Type
|
||||||
|
from unittest import skipUnless
|
||||||
|
|
||||||
|
from immutabledict import immutabledict
|
||||||
|
from parameterized import parameterized_class
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
|
AbstractMultiWriterStreamToken,
|
||||||
|
MultiWriterStreamToken,
|
||||||
RoomAlias,
|
RoomAlias,
|
||||||
|
RoomStreamToken,
|
||||||
UserID,
|
UserID,
|
||||||
get_domain_from_id,
|
get_domain_from_id,
|
||||||
get_localpart_from_id,
|
get_localpart_from_id,
|
||||||
@ -29,6 +38,7 @@ from synapse.types import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||||
|
|
||||||
|
|
||||||
class IsMineIDTests(unittest.HomeserverTestCase):
|
class IsMineIDTests(unittest.HomeserverTestCase):
|
||||||
@ -127,3 +137,64 @@ class MapUsernameTestCase(unittest.TestCase):
|
|||||||
# this should work with either a unicode or a bytes
|
# this should work with either a unicode or a bytes
|
||||||
self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
|
self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
|
||||||
self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
|
self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
|
||||||
|
|
||||||
|
|
||||||
|
@parameterized_class(
|
||||||
|
("token_type",),
|
||||||
|
[
|
||||||
|
(MultiWriterStreamToken,),
|
||||||
|
(RoomStreamToken,),
|
||||||
|
],
|
||||||
|
class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}",
|
||||||
|
)
|
||||||
|
class MultiWriterTokenTestCase(unittest.HomeserverTestCase):
|
||||||
|
"""Tests for the different types of multi writer tokens."""
|
||||||
|
|
||||||
|
token_type: Type[AbstractMultiWriterStreamToken]
|
||||||
|
|
||||||
|
def test_basic_token(self) -> None:
|
||||||
|
"""Test that a simple stream token can be serialized and unserialized"""
|
||||||
|
store = self.hs.get_datastores().main
|
||||||
|
|
||||||
|
token = self.token_type(stream=5)
|
||||||
|
|
||||||
|
string_token = self.get_success(token.to_string(store))
|
||||||
|
|
||||||
|
if isinstance(token, RoomStreamToken):
|
||||||
|
self.assertEqual(string_token, "s5")
|
||||||
|
else:
|
||||||
|
self.assertEqual(string_token, "5")
|
||||||
|
|
||||||
|
parsed_token = self.get_success(self.token_type.parse(store, string_token))
|
||||||
|
self.assertEqual(parsed_token, token)
|
||||||
|
|
||||||
|
@skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres")
|
||||||
|
def test_instance_map(self) -> None:
|
||||||
|
"""Test for stream token with instance map"""
|
||||||
|
store = self.hs.get_datastores().main
|
||||||
|
|
||||||
|
token = self.token_type(stream=5, instance_map=immutabledict({"foo": 6}))
|
||||||
|
|
||||||
|
string_token = self.get_success(token.to_string(store))
|
||||||
|
self.assertEqual(string_token, "m5~1.6")
|
||||||
|
|
||||||
|
parsed_token = self.get_success(self.token_type.parse(store, string_token))
|
||||||
|
self.assertEqual(parsed_token, token)
|
||||||
|
|
||||||
|
def test_instance_map_assertion(self) -> None:
|
||||||
|
"""Test that we assert values in the instance map are greater than the
|
||||||
|
min stream position"""
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.token_type(stream=5, instance_map=immutabledict({"foo": 4}))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self.token_type(stream=5, instance_map=immutabledict({"foo": 5}))
|
||||||
|
|
||||||
|
def test_parse_bad_token(self) -> None:
|
||||||
|
"""Test that we can parse tokens produced by a bug in Synapse of the
|
||||||
|
form `m5~`"""
|
||||||
|
store = self.hs.get_datastores().main
|
||||||
|
|
||||||
|
parsed_token = self.get_success(self.token_type.parse(store, "m5~"))
|
||||||
|
self.assertEqual(parsed_token, self.token_type(stream=5))
|
||||||
|
Loading…
Reference in New Issue
Block a user