Merge branch 'develop' of github.com:matrix-org/synapse into develop

This commit is contained in:
David Robertson 2023-10-24 14:23:19 +01:00
commit c0d2f7649e
No known key found for this signature in database
GPG Key ID: 903ECE108A39DEDD
33 changed files with 787 additions and 461 deletions

1
changelog.d/16471.bugfix Normal file
View File

@ -0,0 +1 @@
Fixed a bug that prevents Grafana from finding the correct datasource. Contributed by @MichaelSasser.

1
changelog.d/16473.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing, exceedingly rare edge case where the first event persisted by a new event persister worker might not be sent down `/sync`.

1
changelog.d/16515.misc Normal file
View File

@ -0,0 +1 @@
Remove duplicate call to mark remote server 'awake' when using a federation sending worker.

1
changelog.d/16521.misc Normal file
View File

@ -0,0 +1 @@
Stop deleting from an unused table.

1
changelog.d/16526.misc Normal file
View File

@ -0,0 +1 @@
Improve type hints.

1
changelog.d/16530.bugfix Normal file
View File

@ -0,0 +1 @@
Force TLS certificate verification in user registration script.

1
changelog.d/16531.doc Normal file
View File

@ -0,0 +1 @@
Add a sentence to the opentracing docs on how you can have jaeger in a different place than synapse.

1
changelog.d/16539.misc Normal file
View File

@ -0,0 +1 @@
Bump matrix-synapse-ldap3 from 0.2.2 to 0.3.0.

1
changelog.d/16540.bugfix Normal file
View File

@ -0,0 +1 @@
Fix long-standing bug where `/sync` could tightloop after restart when using SQLite.

1
changelog.d/16541.doc Normal file
View File

@ -0,0 +1 @@
Correctly describe the meaning of unspecified rule lists in the [`alias_creation_rules`](https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#alias_creation_rules) and [`room_list_publication_rules`](https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#room_list_publication_rules) config options and improve their descriptions more generally.

File diff suppressed because it is too large Load Diff

View File

@ -51,17 +51,24 @@ will be inserted with that ID.
For any given stream reader (including writers themselves), we may define a per-writer current stream ID: For any given stream reader (including writers themselves), we may define a per-writer current stream ID:
> The current stream ID _for a writer W_ is the largest stream ID such that > A current stream ID _for a writer W_ is the largest stream ID such that
> all transactions added by W with equal or smaller ID have completed. > all transactions added by W with equal or smaller ID have completed.
Similarly, there is a "linear" notion of current stream ID: Similarly, there is a "linear" notion of current stream ID:
> The "linear" current stream ID is the largest stream ID such that > A "linear" current stream ID is the largest stream ID such that
> all facts (added by any writer) with equal or smaller ID have completed. > all facts (added by any writer) with equal or smaller ID have completed.
Because different stream readers A and B learn about new facts at different times, A and B may disagree about current stream IDs. Because different stream readers A and B learn about new facts at different times, A and B may disagree about current stream IDs.
Put differently: we should think of stream readers as being independent of each other, proceeding through a stream of facts at different rates. Put differently: we should think of stream readers as being independent of each other, proceeding through a stream of facts at different rates.
The above definition does not give a unique current stream ID, in fact there can
be a range of current stream IDs. Synapse uses both the minimum and maximum IDs
for different purposes. Most often the maximum is used, as its generally
beneficial for workers to advance their IDs as soon as possible. However, the
minimum is used in situations where e.g. another worker is going to wait until
the stream advances past a position.
**NB.** For both senses of "current", that if a writer opens a transaction that never completes, the current stream ID will never advance beyond that writer's last written stream ID. **NB.** For both senses of "current", that if a writer opens a transaction that never completes, the current stream ID will never advance beyond that writer's last written stream ID.
For single-writer streams, the per-writer current ID and the linear current ID are the same. For single-writer streams, the per-writer current ID and the linear current ID are the same.

View File

@ -51,6 +51,11 @@ docker run -d --name jaeger \
jaegertracing/all-in-one:1 jaegertracing/all-in-one:1
``` ```
By default, Synapse will publish traces to Jaeger on localhost.
If Jaeger is hosted elsewhere, point Synapse to the correct host by setting
`opentracing.jaeger_config.local_agent.reporting_host` [in the Synapse configuration](usage/configuration/config_documentation.md#opentracing-1)
or by setting the `JAEGER_AGENT_HOST` environment variable to the desired address.
Latest documentation is probably at Latest documentation is probably at
https://www.jaegertracing.io/docs/latest/getting-started. https://www.jaegertracing.io/docs/latest/getting-started.

View File

@ -3797,62 +3797,160 @@ enable_room_list_search: false
--- ---
### `alias_creation_rules` ### `alias_creation_rules`
The `alias_creation_rules` option controls who is allowed to create aliases The `alias_creation_rules` option allows server admins to prevent unwanted
on this server. alias creation on this server.
The format of this option is a list of rules that contain globs that This setting is an optional list of 0 or more rules. By default, no list is
match against user_id, room_id and the new alias (fully qualified with provided, meaning that all alias creations are permitted.
server name). The action in the first rule that matches is taken,
which can currently either be "allow" or "deny".
Missing user_id/room_id/alias fields default to "*". Otherwise, requests to create aliases are matched against each rule in order.
The first rule that matches decides if the request is allowed or denied. If no
rule matches, the request is denied. In particular, this means that configuring
an empty list of rules will deny every alias creation request.
If no rules match the request is denied. An empty list means no one Each rule is a YAML object containing four fields, each of which is an optional string:
can create aliases.
Options for the rules include: * `user_id`: a glob pattern that matches against the creator of the alias.
* `user_id`: Matches against the creator of the alias. Defaults to "*". * `alias`: a glob pattern that matches against the alias being created.
* `alias`: Matches against the alias being created. Defaults to "*". * `room_id`: a glob pattern that matches against the room ID the alias is being pointed at.
* `room_id`: Matches against the room ID the alias is being pointed at. Defaults to "*" * `action`: either `allow` or `deny`. What to do with the request if the rule matches. Defaults to `allow`.
* `action`: Whether to "allow" or "deny" the request if the rule matches. Defaults to allow.
Each of the glob patterns is optional, defaulting to `*` ("match anything").
Note that the patterns match against fully qualified IDs, e.g. against
`@alice:example.com`, `#room:example.com` and `!abcdefghijk:example.com` instead
of `alice`, `room` and `abcedgghijk`.
Example configuration: Example configuration:
```yaml ```yaml
# No rule list specified. All alias creations are allowed.
# This is the default behaviour.
alias_creation_rules: alias_creation_rules:
- user_id: "bad_user"
alias: "spammy_alias"
room_id: "*"
action: deny
``` ```
```yaml
# A list of one rule which allows everything.
# This has the same effect as the previous example.
alias_creation_rules:
- "action": "allow"
```
```yaml
# An empty list of rules. All alias creations are denied.
alias_creation_rules: []
```
```yaml
# A list of one rule which denies everything.
# This has the same effect as the previous example.
alias_creation_rules:
- "action": "deny"
```
```yaml
# Prevent a specific user from creating aliases.
# Allow other users to create any alias
alias_creation_rules:
- user_id: "@bad_user:example.com"
action: deny
- action: allow
```
```yaml
# Prevent aliases being created which point to a specific room.
alias_creation_rules:
- room_id: "!forbiddenRoom:example.com"
action: deny
- action: allow
```
--- ---
### `room_list_publication_rules` ### `room_list_publication_rules`
The `room_list_publication_rules` option controls who can publish and The `room_list_publication_rules` option allows server admins to prevent
which rooms can be published in the public room list. unwanted entries from being published in the public room list.
The format of this option is the same as that for The format of this option is the same as that for
`alias_creation_rules`. [`alias_creation_rules`](#alias_creation_rules): an optional list of 0 or more
rules. By default, no list is provided, meaning that all rooms may be
published to the room list.
If the room has one or more aliases associated with it, only one of Otherwise, requests to publish a room are matched against each rule in order.
the aliases needs to match the alias rule. If there are no aliases The first rule that matches decides if the request is allowed or denied. If no
then only rules with `alias: *` match. rule matches, the request is denied. In particular, this means that configuring
an empty list of rules will deny every alias creation request.
If no rules match the request is denied. An empty list means no one Each rule is a YAML object containing four fields, each of which is an optional string:
can publish rooms.
* `user_id`: a glob pattern that matches against the user publishing the room.
* `alias`: a glob pattern that matches against one of published room's aliases.
- If the room has no aliases, the alias match fails unless `alias` is unspecified or `*`.
- If the room has exactly one alias, the alias match succeeds if the `alias` pattern matches that alias.
- If the room has two or more aliases, the alias match succeeds if the pattern matches at least one of the aliases.
* `room_id`: a glob pattern that matches against the room ID of the room being published.
* `action`: either `allow` or `deny`. What to do with the request if the rule matches. Defaults to `allow`.
Each of the glob patterns is optional, defaulting to `*` ("match anything").
Note that the patterns match against fully qualified IDs, e.g. against
`@alice:example.com`, `#room:example.com` and `!abcdefghijk:example.com` instead
of `alice`, `room` and `abcedgghijk`.
Options for the rules include:
* `user_id`: Matches against the creator of the alias. Defaults to "*".
* `alias`: Matches against any current local or canonical aliases associated with the room. Defaults to "*".
* `room_id`: Matches against the room ID being published. Defaults to "*".
* `action`: Whether to "allow" or "deny" the request if the rule matches. Defaults to allow.
Example configuration: Example configuration:
```yaml ```yaml
# No rule list specified. Anyone may publish any room to the public list.
# This is the default behaviour.
room_list_publication_rules: room_list_publication_rules:
- user_id: "*" ```
alias: "*"
room_id: "*" ```yaml
action: allow # A list of one rule which allows everything.
# This has the same effect as the previous example.
room_list_publication_rules:
- "action": "allow"
```
```yaml
# An empty list of rules. No-one may publish to the room list.
room_list_publication_rules: []
```
```yaml
# A list of one rule which denies everything.
# This has the same effect as the previous example.
room_list_publication_rules:
- "action": "deny"
```
```yaml
# Prevent a specific user from publishing rooms.
# Allow other users to publish anything.
room_list_publication_rules:
- user_id: "@bad_user:example.com"
action: deny
- action: allow
```
```yaml
# Prevent publication of a specific room.
room_list_publication_rules:
- room_id: "!forbiddenRoom:example.com"
action: deny
- action: allow
```
```yaml
# Prevent publication of rooms with at least one alias containing the word "potato".
room_list_publication_rules:
- alias: "#*potato*:example.com"
action: deny
- action: allow
``` ```
--- ---

8
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. # This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
[[package]] [[package]]
name = "alabaster" name = "alabaster"
@ -1337,13 +1337,13 @@ test = ["aiounittest", "tox", "twisted"]
[[package]] [[package]]
name = "matrix-synapse-ldap3" name = "matrix-synapse-ldap3"
version = "0.2.2" version = "0.3.0"
description = "An LDAP3 auth provider for Synapse" description = "An LDAP3 auth provider for Synapse"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "matrix-synapse-ldap3-0.2.2.tar.gz", hash = "sha256:b388d95693486eef69adaefd0fd9e84463d52fe17b0214a00efcaa669b73cb74"}, {file = "matrix-synapse-ldap3-0.3.0.tar.gz", hash = "sha256:8bb6517173164d4b9cc44f49de411d8cebdb2e705d5dd1ea1f38733c4a009e1d"},
{file = "matrix_synapse_ldap3-0.2.2-py3-none-any.whl", hash = "sha256:66ee4c85d7952c6c27fd04c09cdfdf4847b8e8b7d6a7ada6ba1100013bda060f"}, {file = "matrix_synapse_ldap3-0.3.0-py3-none-any.whl", hash = "sha256:8b4d701f8702551e98cc1d8c20dbed532de5613584c08d0df22de376ba99159d"},
] ]
[package.dependencies] [package.dependencies]

View File

@ -50,7 +50,7 @@ def request_registration(
url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),) url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),)
# Get the nonce # Get the nonce
r = requests.get(url, verify=False) r = requests.get(url)
if r.status_code != 200: if r.status_code != 200:
_print("ERROR! Received %d %s" % (r.status_code, r.reason)) _print("ERROR! Received %d %s" % (r.status_code, r.reason))
@ -88,7 +88,7 @@ def request_registration(
} }
_print("Sending registration request...") _print("Sending registration request...")
r = requests.post(url, json=data, verify=False) r = requests.post(url, json=data)
if r.status_code != 200: if r.status_code != 200:
_print("ERROR! Received %d %s" % (r.status_code, r.reason)) _print("ERROR! Received %d %s" % (r.status_code, r.reason))

View File

@ -238,7 +238,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
data[_STREAM_POSITION_KEY] = { data[_STREAM_POSITION_KEY] = {
"streams": { "streams": {
stream.NAME: stream.current_token(local_instance_name) stream.NAME: stream.minimal_local_current_token()
for stream in streams for stream in streams
}, },
"instance_name": local_instance_name, "instance_name": local_instance_name,

View File

@ -279,14 +279,6 @@ class ReplicationDataHandler:
# may be streaming. # may be streaming.
self.notifier.notify_replication() self.notifier.notify_replication()
def on_remote_server_up(self, server: str) -> None:
"""Called when get a new REMOTE_SERVER_UP command."""
# Let's wake up the transaction queue for the server in case we have
# pending stuff to send to it.
if self.send_handler:
self.send_handler.wake_destination(server)
async def wait_for_stream_position( async def wait_for_stream_position(
self, self,
instance_name: str, instance_name: str,
@ -405,9 +397,6 @@ class FederationSenderHandler:
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer") self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
def wake_destination(self, server: str) -> None:
self.federation_sender.wake_destination(server)
async def process_replication_rows( async def process_replication_rows(
self, stream_name: str, token: int, rows: list self, stream_name: str, token: int, rows: list
) -> None: ) -> None:

View File

@ -657,8 +657,6 @@ class ReplicationCommandHandler:
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
) -> None: ) -> None:
"""Called when get a new REMOTE_SERVER_UP command.""" """Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
self._notifier.notify_remote_server_up(cmd.data) self._notifier.notify_remote_server_up(cmd.data)
def on_LOCK_RELEASED( def on_LOCK_RELEASED(

View File

@ -33,6 +33,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -107,22 +108,10 @@ class Stream:
def __init__( def __init__(
self, self,
local_instance_name: str, local_instance_name: str,
current_token_function: Callable[[str], Token],
update_function: UpdateFunction, update_function: UpdateFunction,
): ):
"""Instantiate a Stream """Instantiate a Stream
`current_token_function` and `update_function` are callbacks which
should be implemented by subclasses.
`current_token_function` takes an instance name, which is a writer to
the stream, and returns the position in the stream of the writer (as
viewed from the current process). On the writer process this is where
the writer has successfully written up to, whereas on other processes
this is the position which we have received updates up to over
replication. (Note that most streams have a single writer and so their
implementations ignore the instance name passed in).
`update_function` is called to get updates for this stream between a `update_function` is called to get updates for this stream between a
pair of stream tokens. See the `UpdateFunction` type definition for more pair of stream tokens. See the `UpdateFunction` type definition for more
info. info.
@ -133,12 +122,28 @@ class Stream:
update_function: callback go get stream updates, as above update_function: callback go get stream updates, as above
""" """
self.local_instance_name = local_instance_name self.local_instance_name = local_instance_name
self.current_token = current_token_function
self.update_function = update_function self.update_function = update_function
# The token from which we last asked for updates # The token from which we last asked for updates
self.last_token = self.current_token(self.local_instance_name) self.last_token = self.current_token(self.local_instance_name)
def current_token(self, instance_name: str) -> Token:
"""This takes an instance name, which is a writer to
the stream, and returns the position in the stream of the writer (as
viewed from the current process).
"""
# We can't make this an abstract class as it makes mypy unhappy.
raise NotImplementedError()
def minimal_local_current_token(self) -> Token:
"""Tries to return a minimal current token for the local instance,
i.e. for writers this would be the last successful write.
If local instance is not a writer (or has written yet) then falls back
to returning the normal "current token".
"""
raise NotImplementedError()
def discard_updates_and_advance(self) -> None: def discard_updates_and_advance(self) -> None:
"""Called when the stream should advance but the updates would be discarded, """Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers. e.g. when there are no currently connected workers.
@ -190,6 +195,25 @@ class Stream:
return updates, upto_token, limited return updates, upto_token, limited
class _StreamFromIdGen(Stream):
"""Helper class for simple streams that use a stream ID generator"""
def __init__(
self,
local_instance_name: str,
update_function: UpdateFunction,
stream_id_gen: "AbstractStreamIdGenerator",
):
self._stream_id_gen = stream_id_gen
super().__init__(local_instance_name, update_function)
def current_token(self, instance_name: str) -> Token:
return self._stream_id_gen.get_current_token_for_writer(instance_name)
def minimal_local_current_token(self) -> Token:
return self._stream_id_gen.get_minimal_local_current_token()
def current_token_without_instance( def current_token_without_instance(
current_token: Callable[[], int] current_token: Callable[[], int]
) -> Callable[[str], int]: ) -> Callable[[str], int]:
@ -242,17 +266,21 @@ class BackfillStream(Stream):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
self._current_token,
self.store.get_all_new_backfill_event_rows, self.store.get_all_new_backfill_event_rows,
) )
def _current_token(self, instance_name: str) -> int: def current_token(self, instance_name: str) -> Token:
# The backfill stream over replication operates on *positive* numbers, # The backfill stream over replication operates on *positive* numbers,
# which means we need to negate it. # which means we need to negate it.
return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name) return -self.store._backfill_id_gen.get_current_token_for_writer(instance_name)
def minimal_local_current_token(self) -> Token:
# The backfill stream over replication operates on *positive* numbers,
# which means we need to negate it.
return -self.store._backfill_id_gen.get_minimal_local_current_token()
class PresenceStream(Stream):
class PresenceStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class PresenceStreamRow: class PresenceStreamRow:
user_id: str user_id: str
@ -283,9 +311,7 @@ class PresenceStream(Stream):
update_function = make_http_update_function(hs, self.NAME) update_function = make_http_update_function(hs, self.NAME)
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(), update_function, store._presence_id_gen
current_token_without_instance(store.get_current_presence_token),
update_function,
) )
@ -305,13 +331,18 @@ class PresenceFederationStream(Stream):
ROW_TYPE = PresenceFederationStreamRow ROW_TYPE = PresenceFederationStreamRow
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
federation_queue = hs.get_presence_handler().get_federation_queue() self._federation_queue = hs.get_presence_handler().get_federation_queue()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
federation_queue.get_current_token, self._federation_queue.get_replication_rows,
federation_queue.get_replication_rows,
) )
def current_token(self, instance_name: str) -> Token:
return self._federation_queue.get_current_token(instance_name)
def minimal_local_current_token(self) -> Token:
return self._federation_queue.get_current_token(self.local_instance_name)
class TypingStream(Stream): class TypingStream(Stream):
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -341,20 +372,25 @@ class TypingStream(Stream):
update_function: Callable[ update_function: Callable[
[str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]] [str, int, int, int], Awaitable[Tuple[List[Tuple[int, Any]], int, bool]]
] = typing_writer_handler.get_all_typing_updates ] = typing_writer_handler.get_all_typing_updates
current_token_function = typing_writer_handler.get_current_token self.current_token_function = typing_writer_handler.get_current_token
else: else:
# Query the typing writer process # Query the typing writer process
update_function = make_http_update_function(hs, self.NAME) update_function = make_http_update_function(hs, self.NAME)
current_token_function = hs.get_typing_handler().get_current_token self.current_token_function = hs.get_typing_handler().get_current_token
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(current_token_function),
update_function, update_function,
) )
def current_token(self, instance_name: str) -> Token:
return self.current_token_function()
class ReceiptsStream(Stream): def minimal_local_current_token(self) -> Token:
return self.current_token_function()
class ReceiptsStream(_StreamFromIdGen):
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class ReceiptsStreamRow: class ReceiptsStreamRow:
room_id: str room_id: str
@ -371,12 +407,12 @@ class ReceiptsStream(Stream):
store = hs.get_datastores().main store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(store.get_max_receipt_stream_id),
store.get_all_updated_receipts, store.get_all_updated_receipts,
store._receipts_id_gen,
) )
class PushRulesStream(Stream): class PushRulesStream(_StreamFromIdGen):
"""A user has changed their push rules""" """A user has changed their push rules"""
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -387,20 +423,16 @@ class PushRulesStream(Stream):
ROW_TYPE = PushRulesStreamRow ROW_TYPE = PushRulesStreamRow
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
self._current_token, store.get_all_push_rule_updates,
self.store.get_all_push_rule_updates, store._push_rules_stream_id_gen,
) )
def _current_token(self, instance_name: str) -> int:
push_rules_token = self.store.get_max_push_rules_stream_id()
return push_rules_token
class PushersStream(_StreamFromIdGen):
class PushersStream(Stream):
"""A user has added/changed/removed a pusher""" """A user has added/changed/removed a pusher"""
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -418,8 +450,8 @@ class PushersStream(Stream):
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(store.get_pushers_stream_token),
store.get_all_updated_pushers_rows, store.get_all_updated_pushers_rows,
store._pushers_id_gen,
) )
@ -447,15 +479,20 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow ROW_TYPE = CachesStreamRow
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
store = hs.get_datastores().main self.store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_cache_stream_token_for_writer, self.store.get_all_updated_caches,
store.get_all_updated_caches,
) )
def current_token(self, instance_name: str) -> Token:
return self.store.get_cache_stream_token_for_writer(instance_name)
class DeviceListsStream(Stream): def minimal_local_current_token(self) -> Token:
return self.current_token(self.local_instance_name)
class DeviceListsStream(_StreamFromIdGen):
"""Either a user has updated their devices or a remote server needs to be """Either a user has updated their devices or a remote server needs to be
told about a device update. told about a device update.
""" """
@ -473,8 +510,8 @@ class DeviceListsStream(Stream):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(self.store.get_device_stream_token),
self._update_function, self._update_function,
self.store._device_list_id_gen,
) )
async def _update_function( async def _update_function(
@ -525,7 +562,7 @@ class DeviceListsStream(Stream):
return updates, upper_limit_token, devices_limited or signatures_limited return updates, upper_limit_token, devices_limited or signatures_limited
class ToDeviceStream(Stream): class ToDeviceStream(_StreamFromIdGen):
"""New to_device messages for a client""" """New to_device messages for a client"""
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -539,12 +576,12 @@ class ToDeviceStream(Stream):
store = hs.get_datastores().main store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(store.get_to_device_stream_token),
store.get_all_new_device_messages, store.get_all_new_device_messages,
store._device_inbox_id_gen,
) )
class AccountDataStream(Stream): class AccountDataStream(_StreamFromIdGen):
"""Global or per room account data was changed""" """Global or per room account data was changed"""
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -560,8 +597,8 @@ class AccountDataStream(Stream):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(self.store.get_max_account_data_stream_id),
self._update_function, self._update_function,
self.store._account_data_id_gen,
) )
async def _update_function( async def _update_function(

View File

@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast
import attr import attr
from synapse.replication.tcp.streams._base import ( from synapse.replication.tcp.streams._base import (
Stream,
StreamRow, StreamRow,
StreamUpdateResult, StreamUpdateResult,
Token, Token,
_StreamFromIdGen,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@ -139,7 +139,7 @@ _EventRows: Tuple[Type[BaseEventsStreamRow], ...] = (
TypeToRow = {Row.TypeId: Row for Row in _EventRows} TypeToRow = {Row.TypeId: Row for Row in _EventRows}
class EventsStream(Stream): class EventsStream(_StreamFromIdGen):
"""We received a new event, or an event went from being an outlier to not""" """We received a new event, or an event went from being an outlier to not"""
NAME = "events" NAME = "events"
@ -147,9 +147,7 @@ class EventsStream(Stream):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main self._store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(), self._update_function, self._store._stream_id_gen
self._store._stream_id_gen.get_current_token_for_writer,
self._update_function,
) )
async def _update_function( async def _update_function(

View File

@ -18,6 +18,7 @@ import attr
from synapse.replication.tcp.streams._base import ( from synapse.replication.tcp.streams._base import (
Stream, Stream,
Token,
current_token_without_instance, current_token_without_instance,
make_http_update_function, make_http_update_function,
) )
@ -47,7 +48,7 @@ class FederationStream(Stream):
# will be a real FederationSender, which has stubs for current_token and # will be a real FederationSender, which has stubs for current_token and
# get_replication_rows.) # get_replication_rows.)
federation_sender = hs.get_federation_sender() federation_sender = hs.get_federation_sender()
current_token = current_token_without_instance( self.current_token_func = current_token_without_instance(
federation_sender.get_current_token federation_sender.get_current_token
) )
update_function: Callable[ update_function: Callable[
@ -57,15 +58,21 @@ class FederationStream(Stream):
elif hs.should_send_federation(): elif hs.should_send_federation():
# federation sender: Query master process # federation sender: Query master process
update_function = make_http_update_function(hs, self.NAME) update_function = make_http_update_function(hs, self.NAME)
current_token = self._stub_current_token self.current_token_func = self._stub_current_token
else: else:
# other worker: stub out the update function (we're not interested in # other worker: stub out the update function (we're not interested in
# any updates so when we get a POSITION we do nothing) # any updates so when we get a POSITION we do nothing)
update_function = self._stub_update_function update_function = self._stub_update_function
current_token = self._stub_current_token self.current_token_func = self._stub_current_token
super().__init__(hs.get_instance_name(), current_token, update_function) super().__init__(hs.get_instance_name(), update_function)
def current_token(self, instance_name: str) -> Token:
return self.current_token_func(instance_name)
def minimal_local_current_token(self) -> Token:
return self.current_token(self.local_instance_name)
@staticmethod @staticmethod
def _stub_current_token(instance_name: str) -> int: def _stub_current_token(instance_name: str) -> int:

View File

@ -15,7 +15,7 @@ from typing import TYPE_CHECKING
import attr import attr
from synapse.replication.tcp.streams import Stream from synapse.replication.tcp.streams._base import _StreamFromIdGen
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -27,7 +27,7 @@ class UnPartialStatedRoomStreamRow:
room_id: str room_id: str
class UnPartialStatedRoomStream(Stream): class UnPartialStatedRoomStream(_StreamFromIdGen):
""" """
Stream to notify about rooms becoming un-partial-stated; Stream to notify about rooms becoming un-partial-stated;
that is, when the background sync finishes such that we now have full state for that is, when the background sync finishes such that we now have full state for
@ -41,8 +41,8 @@ class UnPartialStatedRoomStream(Stream):
store = hs.get_datastores().main store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_un_partial_stated_rooms_token,
store.get_un_partial_stated_rooms_from_stream, store.get_un_partial_stated_rooms_from_stream,
store._un_partial_stated_rooms_stream_id_gen,
) )
@ -56,7 +56,7 @@ class UnPartialStatedEventStreamRow:
rejection_status_changed: bool rejection_status_changed: bool
class UnPartialStatedEventStream(Stream): class UnPartialStatedEventStream(_StreamFromIdGen):
""" """
Stream to notify about events becoming un-partial-stated. Stream to notify about events becoming un-partial-stated.
""" """
@ -68,6 +68,6 @@ class UnPartialStatedEventStream(Stream):
store = hs.get_datastores().main store = hs.get_datastores().main
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
store.get_un_partial_stated_events_token,
store.get_un_partial_stated_events_from_stream, store.get_un_partial_stated_events_from_stream,
store._un_partial_stated_events_stream_id_gen,
) )

View File

@ -94,7 +94,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
hs.get_replication_notifier(), hs.get_replication_notifier(),
"room_account_data", "room_account_data",
"stream_id", "stream_id",
extra_tables=[("room_tags_revisions", "stream_id")], extra_tables=[
("account_data", "stream_id"),
("room_tags_revisions", "stream_id"),
],
is_writer=self._instance_name in hs.config.worker.writers.account_data, is_writer=self._instance_name in hs.config.worker.writers.account_data,
) )

View File

@ -2095,12 +2095,6 @@ class EventsWorkerStore(SQLBaseStore):
def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None: def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000 one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
sql = """
DELETE FROM event_txn_id
WHERE inserted_ts < ?
"""
txn.execute(sql, (one_day_ago,))
sql = """ sql = """
DELETE FROM event_txn_id_device_id DELETE FROM event_txn_id_device_id
WHERE inserted_ts < ? WHERE inserted_ts < ?

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
SCHEMA_VERSION = 82 # remember to update the list below when updating SCHEMA_VERSION = 83 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema """Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the This should be incremented whenever the codebase changes its requirements on the
@ -121,6 +121,9 @@ Changes in SCHEMA_VERSION = 81
Changes in SCHEMA_VERSION = 82 Changes in SCHEMA_VERSION = 82
- The insertion_events, insertion_event_extremities, insertion_event_edges, and - The insertion_events, insertion_event_extremities, insertion_event_edges, and
batch_events tables are no longer purged in preparation for their removal. batch_events tables are no longer purged in preparation for their removal.
Changes in SCHEMA_VERSION = 83
- The event_txn_id is no longer used.
""" """

View File

@ -133,6 +133,15 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
""" """
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod
def get_minimal_local_current_token(self) -> int:
"""Tries to return a minimal current token for the local instance,
i.e. for writers this would be the last successful write.
If local instance is not a writer (or has written yet) then falls back
to returning the normal "current token".
"""
@abc.abstractmethod @abc.abstractmethod
def get_next(self) -> AsyncContextManager[int]: def get_next(self) -> AsyncContextManager[int]:
""" """
@ -312,6 +321,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
def get_current_token_for_writer(self, instance_name: str) -> int: def get_current_token_for_writer(self, instance_name: str) -> int:
return self.get_current_token() return self.get_current_token()
def get_minimal_local_current_token(self) -> int:
return self.get_current_token()
class MultiWriterIdGenerator(AbstractStreamIdGenerator): class MultiWriterIdGenerator(AbstractStreamIdGenerator):
"""Generates and tracks stream IDs for a stream with multiple writers. """Generates and tracks stream IDs for a stream with multiple writers.
@ -408,6 +420,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# The maximum stream ID that we have seen been allocated across any writer. # The maximum stream ID that we have seen been allocated across any writer.
self._max_seen_allocated_stream_id = 1 self._max_seen_allocated_stream_id = 1
# The maximum position of the local instance. This can be higher than
# the corresponding position in `current_positions` table when there are
# no active writes in progress.
self._max_position_of_local_instance = self._max_seen_allocated_stream_id
self._sequence_gen = PostgresSequenceGenerator(sequence_name) self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged. # We check that the table and sequence haven't diverged.
@ -427,6 +444,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._current_positions.values(), default=1 self._current_positions.values(), default=1
) )
# For the case where `stream_positions` is not up to date,
# `_persisted_upto_position` may be higher.
self._max_seen_allocated_stream_id = max(
self._max_seen_allocated_stream_id, self._persisted_upto_position
)
# Bump our local maximum position now that we've loaded things from the
# DB.
self._max_position_of_local_instance = self._max_seen_allocated_stream_id
if not writers: if not writers:
# If there have been no explicit writers given then any instance can # If there have been no explicit writers given then any instance can
# write to the stream. In which case, let's pre-seed our own # write to the stream. In which case, let's pre-seed our own
@ -545,6 +572,14 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
if instance == self._instance_name: if instance == self._instance_name:
self._current_positions[instance] = stream_id self._current_positions[instance] = stream_id
if self._writers:
# If we have explicit writers then make sure that each instance has
# a position.
for writer in self._writers:
self._current_positions.setdefault(
writer, self._persisted_upto_position
)
cur.close() cur.close()
def _load_next_id_txn(self, txn: Cursor) -> int: def _load_next_id_txn(self, txn: Cursor) -> int:
@ -688,6 +723,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
if new_cur: if new_cur:
curr = self._current_positions.get(self._instance_name, 0) curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, new_cur) self._current_positions[self._instance_name] = max(curr, new_cur)
self._max_position_of_local_instance = max(
curr, new_cur, self._max_position_of_local_instance
)
self._add_persisted_position(next_id) self._add_persisted_position(next_id)
@ -702,10 +740,26 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# persisted up to position. This stops Synapse from doing a full table # persisted up to position. This stops Synapse from doing a full table
# scan when a new writer announces itself over replication. # scan when a new writer announces itself over replication.
with self._lock: with self._lock:
return self._return_factor * self._current_positions.get( if self._instance_name == instance_name:
return self._return_factor * self._max_position_of_local_instance
pos = self._current_positions.get(
instance_name, self._persisted_upto_position instance_name, self._persisted_upto_position
) )
# We want to return the maximum "current token" that we can for a
# writer, this helps ensure that streams progress as fast as
# possible.
pos = max(pos, self._persisted_upto_position)
return self._return_factor * pos
def get_minimal_local_current_token(self) -> int:
with self._lock:
return self._return_factor * self._current_positions.get(
self._instance_name, self._persisted_upto_position
)
def get_positions(self) -> Dict[str, int]: def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map. """Get a copy of the current positon map.
@ -774,6 +828,18 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._persisted_upto_position = max(min_curr, self._persisted_upto_position) self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
# Advance our local max position.
self._max_position_of_local_instance = max(
self._max_position_of_local_instance, self._persisted_upto_position
)
if not self._unfinished_ids and not self._in_flight_fetches:
# If we don't have anything in flight, it's safe to advance to the
# max seen stream ID.
self._max_position_of_local_instance = max(
self._max_seen_allocated_stream_id, self._max_position_of_local_instance
)
# We now iterate through the seen positions, discarding those that are # We now iterate through the seen positions, discarding those that are
# less than the current min positions, and incrementing the min position # less than the current min positions, and incrementing the min position
# if its exactly one greater. # if its exactly one greater.

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import queue import queue
from typing import BinaryIO, Optional, Union, cast from typing import Any, BinaryIO, Optional, Union, cast
from twisted.internet import threads from twisted.internet import threads
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
@ -58,7 +58,9 @@ class BackgroundFileConsumer:
self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue() self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue()
# Deferred that is resolved when finished writing # Deferred that is resolved when finished writing
self._finished_deferred: Optional[Deferred[None]] = None #
# This is really Deferred[None], but mypy doesn't seem to like that.
self._finished_deferred: Optional[Deferred[Any]] = None
# If the _writer thread throws an exception it gets stored here. # If the _writer thread throws an exception it gets stored here.
self._write_exception: Optional[Exception] = None self._write_exception: Optional[Exception] = None
@ -80,9 +82,13 @@ class BackgroundFileConsumer:
self.streaming = streaming self.streaming = streaming
self._finished_deferred = run_in_background( self._finished_deferred = run_in_background(
threads.deferToThreadPool, threads.deferToThreadPool,
self._reactor, # mypy seems to get confused with the chaining of ParamSpec from
self._reactor.getThreadPool(), # run_in_background to deferToThreadPool.
self._writer, #
# For Twisted trunk, ignore arg-type; for Twisted release ignore unused-ignore.
self._reactor, # type: ignore[arg-type,unused-ignore]
self._reactor.getThreadPool(), # type: ignore[arg-type,unused-ignore]
self._writer, # type: ignore[arg-type,unused-ignore]
) )
if not streaming: if not streaming:
self._producer.resumeProducing() self._producer.resumeProducing()

View File

@ -156,6 +156,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
result = self.successResultOf( result = self.successResultOf(
defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias)) defer.ensureDeferred(self.handler.query_room_alias_exists(room_alias))
) )
assert result is not None
self.mock_as_api.query_alias.assert_called_once_with( self.mock_as_api.query_alias.assert_called_once_with(
interested_service, room_alias_str interested_service, room_alias_str

View File

@ -335,7 +335,7 @@ class Deferred__next__Patch:
self._request_number = request_number self._request_number = request_number
self._seen_awaits = seen_awaits self._seen_awaits = seen_awaits
self._original_Deferred___next__ = Deferred.__next__ self._original_Deferred___next__ = Deferred.__next__ # type: ignore[misc,unused-ignore]
# The number of `await`s on `Deferred`s we have seen so far. # The number of `await`s on `Deferred`s we have seen so far.
self.awaits_seen = 0 self.awaits_seen = 0

View File

@ -70,7 +70,7 @@ class FederationClientTests(HomeserverTestCase):
""" """
@defer.inlineCallbacks @defer.inlineCallbacks
def do_request() -> Generator["Deferred[object]", object, object]: def do_request() -> Generator["Deferred[Any]", object, object]:
with LoggingContext("one") as context: with LoggingContext("one") as context:
fetch_d = defer.ensureDeferred( fetch_d = defer.ensureDeferred(
self.cl.get_json("testserv:8008", "foo/bar") self.cl.get_json("testserv:8008", "foo/bar")

View File

@ -259,8 +259,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator() id_gen = self._create_id_generator()
# The table is empty so we expect an empty map for positions # The table is empty so we expect the map for positions to have a dummy
self.assertEqual(id_gen.get_positions(), {}) # minimum value.
self.assertEqual(id_gen.get_positions(), {"master": 1})
def test_single_instance(self) -> None: def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled """Test that reads and writes from a single process are handled
@ -349,15 +350,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
first_id_gen = self._create_id_generator("first", writers=["first", "second"]) first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"])
# The first ID gen will notice that it can advance its token to 7 as it
# has no in progress writes...
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
# ... but the second ID gen doesn't know that.
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
# Try allocating a new ID gen and check that we only see position # Try allocating a new ID gen and check that we only see position
@ -398,6 +396,56 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen.advance("first", 8) second_id_gen.advance("first", 8)
self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9}) self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
def test_multi_instance_empty_row(self) -> None:
"""Test that reads and writes from multiple processes are handled
correctly, when one of the writers starts without any rows.
"""
# Insert some rows for two out of three of the ID gens.
self._insert_rows("first", 3)
self._insert_rows("second", 4)
first_id_gen = self._create_id_generator(
"first", writers=["first", "second", "third"]
)
second_id_gen = self._create_id_generator(
"second", writers=["first", "second", "third"]
)
third_id_gen = self._create_id_generator(
"third", writers=["first", "second", "third"]
)
self.assertEqual(
first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
)
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("third"), 7)
self.assertEqual(
second_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
)
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("third"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
async def _get_next_async() -> None:
async with third_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
self.assertEqual(
third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
)
self.assertEqual(third_id_gen.get_persisted_upto_position(), 7)
self.get_success(_get_next_async())
self.assertEqual(
third_id_gen.get_positions(), {"first": 3, "second": 7, "third": 8}
)
def test_get_next_txn(self) -> None: def test_get_next_txn(self) -> None:
"""Test that the `get_next_txn` function works correctly.""" """Test that the `get_next_txn` function works correctly."""
@ -600,6 +648,70 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
with self.assertRaises(IncorrectDatabaseSetup): with self.assertRaises(IncorrectDatabaseSetup):
self._create_id_generator("first") self._create_id_generator("first")
def test_minimal_local_token(self) -> None:
self._insert_rows("first", 3)
self._insert_rows("second", 4)
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(second_id_gen.get_minimal_local_current_token(), 7)
def test_current_token_gap(self) -> None:
"""Test that getting the current token for a writer returns the maximal
token when there are no writes.
"""
self._insert_rows("first", 3)
self._insert_rows("second", 4)
first_id_gen = self._create_id_generator(
"first", writers=["first", "second", "third"]
)
second_id_gen = self._create_id_generator(
"second", writers=["first", "second", "third"]
)
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(second_id_gen.get_current_token(), 7)
# Check that the first ID gen advancing causes the second ID gen to
# advance (as the second ID gen has nothing in flight).
async def _get_next_async() -> None:
async with first_id_gen.get_next_mult(2):
pass
self.get_success(_get_next_async())
second_id_gen.advance("first", 9)
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 9)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
self.assertEqual(second_id_gen.get_current_token(), 7)
# Check that the first ID gen advancing doesn't advance the second ID
# gen when the second ID gen has stuff in flight.
self.get_success(_get_next_async())
ctxmgr = second_id_gen.get_next()
self.get_success(ctxmgr.__aenter__())
second_id_gen.advance("first", 11)
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 9)
self.assertEqual(second_id_gen.get_current_token(), 7)
self.get_success(ctxmgr.__aexit__(None, None, None))
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 11)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 12)
self.assertEqual(second_id_gen.get_current_token(), 7)
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs.""" """Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
@ -712,8 +824,8 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.get_success(_get_next_async()) self.get_success(_get_next_async())
self.assertEqual(id_gen_1.get_positions(), {"first": -1}) self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -1})
self.assertEqual(id_gen_2.get_positions(), {"first": -1}) self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -1})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
@ -822,11 +934,11 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen = self._create_id_generator("second", writers=["first", "second"]) second_id_gen = self._create_id_generator("second", writers=["first", "second"])
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6}) self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6) self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3) self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(second_id_gen.get_persisted_upto_position(), 7) self.assertEqual(second_id_gen.get_persisted_upto_position(), 7)

View File

@ -30,6 +30,7 @@ from typing import (
Generic, Generic,
Iterable, Iterable,
List, List,
Mapping,
NoReturn, NoReturn,
Optional, Optional,
Tuple, Tuple,
@ -251,7 +252,7 @@ class TestCase(unittest.TestCase):
except AssertionError as e: except AssertionError as e:
raise (type(e))(f"Assert error for '.{key}':") from e raise (type(e))(f"Assert error for '.{key}':") from e
def assert_dict(self, required: dict, actual: dict) -> None: def assert_dict(self, required: Mapping, actual: Mapping) -> None:
"""Does a partial assert of a dict. """Does a partial assert of a dict.
Args: Args: