Support any process writing to cache invalidation stream. (#7436)

This commit is contained in:
Erik Johnston 2020-05-07 13:51:08 +01:00 committed by GitHub
parent 2929ce29d6
commit d7983b63a6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 225 additions and 230 deletions

View file

@ -95,20 +95,25 @@ class Stream(object):
def __init__(
self,
local_instance_name: str,
current_token_function: Callable[[], Token],
current_token_function: Callable[[str], Token],
update_function: UpdateFunction,
):
"""Instantiate a Stream
current_token_function and update_function are callbacks which should be
implemented by subclasses.
`current_token_function` and `update_function` are callbacks which
should be implemented by subclasses.
current_token_function is called to get the current token of the underlying
stream. It is only meaningful on the process that is the source of the
replication stream (ie, usually the master).
`current_token_function` takes an instance name, which is a writer to
the stream, and returns the position in the stream of the writer (as
viewed from the current process). On the writer process this is where
the writer has successfully written up to, whereas on other processes
this is the position which we have received updates up to over
replication. (Note that most streams have a single writer and so their
implementations ignore the instance name passed in).
update_function is called to get updates for this stream between a pair of
stream tokens. See the UpdateFunction type definition for more info.
`update_function` is called to get updates for this stream between a
pair of stream tokens. See the `UpdateFunction` type definition for more
info.
Args:
local_instance_name: The instance name of the current process
@ -120,13 +125,13 @@ class Stream(object):
self.update_function = update_function
# The token from which we last asked for updates
self.last_token = self.current_token()
self.last_token = self.current_token(self.local_instance_name)
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
self.last_token = self.current_token()
self.last_token = self.current_token(self.local_instance_name)
async def get_updates(self) -> StreamUpdateResult:
"""Gets all updates since the last time this function was called (or
@ -138,7 +143,7 @@ class Stream(object):
position in stream, and `limited` is whether there are more updates
to fetch.
"""
current_token = self.current_token()
current_token = self.current_token(self.local_instance_name)
updates, current_token, limited = await self.get_updates_since(
self.local_instance_name, self.last_token, current_token
)
@ -170,6 +175,16 @@ class Stream(object):
return updates, upto_token, limited
def current_token_without_instance(
current_token: Callable[[], int]
) -> Callable[[str], int]:
"""Takes a current token callback function for a single writer stream
that doesn't take an instance name parameter and wraps it in a function that
does accept an instance name parameter but ignores it.
"""
return lambda instance_name: current_token()
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
@ -235,7 +250,7 @@ class BackfillStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_current_backfill_token,
current_token_without_instance(store.get_current_backfill_token),
db_query_to_update_function(store.get_all_new_backfill_event_rows),
)
@ -271,7 +286,9 @@ class PresenceStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
hs.get_instance_name(), store.get_current_presence_token, update_function
hs.get_instance_name(),
current_token_without_instance(store.get_current_presence_token),
update_function,
)
@ -296,7 +313,9 @@ class TypingStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
hs.get_instance_name(), typing_handler.get_current_token, update_function
hs.get_instance_name(),
current_token_without_instance(typing_handler.get_current_token),
update_function,
)
@ -319,7 +338,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_max_receipt_stream_id,
current_token_without_instance(store.get_max_receipt_stream_id),
db_query_to_update_function(store.get_all_updated_receipts),
)
@ -339,7 +358,7 @@ class PushRulesStream(Stream):
hs.get_instance_name(), self._current_token, self._update_function
)
def _current_token(self) -> int:
def _current_token(self, instance_name: str) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
@ -373,7 +392,7 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
store.get_pushers_stream_token,
current_token_without_instance(store.get_pushers_stream_token),
db_query_to_update_function(store.get_all_updated_pushers_rows),
)
@ -402,13 +421,27 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_cache_stream_token,
db_query_to_update_function(store.get_all_updated_caches),
self.store.get_cache_stream_token,
self._update_function,
)
async def _update_function(
self, instance_name: str, from_token: int, upto_token: int, limit: int
):
rows = await self.store.get_all_updated_caches(
instance_name, from_token, upto_token, limit
)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
class PublicRoomsStream(Stream):
"""The public rooms list changed
@ -431,7 +464,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_current_public_room_stream_id,
current_token_without_instance(store.get_current_public_room_stream_id),
db_query_to_update_function(store.get_all_new_public_rooms),
)
@ -452,7 +485,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_device_stream_token,
current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
)
@ -470,7 +503,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_to_device_stream_token,
current_token_without_instance(store.get_to_device_stream_token),
db_query_to_update_function(store.get_all_new_device_messages),
)
@ -490,7 +523,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_max_account_data_stream_id,
current_token_without_instance(store.get_max_account_data_stream_id),
db_query_to_update_function(store.get_all_updated_tags),
)
@ -510,7 +543,7 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
self.store.get_max_account_data_stream_id,
current_token_without_instance(self.store.get_max_account_data_stream_id),
db_query_to_update_function(self._update_function),
)
@ -541,7 +574,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_group_stream_token,
current_token_without_instance(store.get_group_stream_token),
db_query_to_update_function(store.get_all_groups_changes),
)
@ -559,7 +592,7 @@ class UserSignatureStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_device_stream_token,
current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes
),