Use async with for ID gens (#8383)

This will allow us to hit the DB after we've finished using the generated stream ID.
Erik Johnston 2020-09-23 16:11:18 +01:00 committed by GitHub
parent 916bb9d0d1
commit cbabb312e0
15 changed files with 144 additions and 105 deletions
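To illustrate the pattern (a minimal, self-contained sketch; `MyIdGenerator`, `_AsyncIdContextManager` and `_allocate_ids_in_db` are hypothetical stand-ins, not the Synapse classes in this diff): `get_next()` stops being a coroutine that returns a plain context manager, and instead returns an async context manager, so the ID-allocating query runs in `__aenter__` and further DB work can run in `__aexit__`, after the caller has persisted its rows.

import asyncio
from typing import List


async def _allocate_ids_in_db(n: int) -> List[int]:
    # Hypothetical stand-in for the `runInteraction` call that fetches
    # IDs from the database sequence.
    return list(range(1, n + 1))


class MyIdGenerator:
    def get_next(self):
        # Not a coroutine: callers write `async with gen.get_next()`,
        # and all awaiting happens inside __aenter__/__aexit__.
        return _AsyncIdContextManager()


class _AsyncIdContextManager:
    async def __aenter__(self) -> int:
        # Hit the DB on entry to allocate the stream ID...
        self.stream_id = (await _allocate_ids_in_db(1))[0]
        return self.stream_id

    async def __aexit__(self, exc_type, exc, tb) -> bool:
        # ...and the DB can be hit again here, *after* the caller has
        # finished using the generated stream ID.
        return False


async def main():
    gen = MyIdGenerator()
    # Old style was `with await gen.get_next() as stream_id:`; new style:
    async with gen.get_next() as stream_id:
        print("persisting under stream ID", stream_id)


asyncio.run(main())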

changelog.d/8383.misc (new file)

@@ -0,0 +1 @@
+Refactor ID generators to use `async with` syntax.

synapse/storage/databases/main/account_data.py

@@ -339,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as room_account_data has a unique constraint
             # on (user_id, room_id, account_data_type) so simple_upsert will
             # retry if there is a conflict.
@@ -387,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore):
         """
         content_json = json_encoder.encode(content)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             # no need to lock here as account_data has a unique constraint on
             # (user_id, account_data_type) so simple_upsert will retry if
             # there is a conflict.

synapse/storage/databases/main/deviceinbox.py

@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 rows.append((destination, stream_id, now_ms, edu_json))
             txn.executemany(sql, rows)
 
-        with await self._device_inbox_id_gen.get_next() as stream_id:
+        async with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 txn, stream_id, local_messages_by_user_then_device
             )
 
-        with await self._device_inbox_id_gen.get_next() as stream_id:
+        async with self._device_inbox_id_gen.get_next() as stream_id:
             now_ms = self.clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_from_remote_to_device_inbox",

synapse/storage/databases/main/devices.py

@@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
             THe new stream ID.
         """
 
-        with await self._device_list_id_gen.get_next() as stream_id:
+        async with self._device_list_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
@@ -1093,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         if not device_ids:
             return
 
-        with await self._device_list_id_gen.get_next_mult(
+        async with self._device_list_id_gen.get_next_mult(
             len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(
@@ -1108,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             return stream_ids[-1]
 
         context = get_active_span_text_map()
-        with await self._device_list_id_gen.get_next_mult(
+        async with self._device_list_id_gen.get_next_mult(
             len(hosts) * len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(

synapse/storage/databases/main/end_to_end_keys.py

@@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             key (dict): the key data
         """
 
-        with await self._cross_signing_id_gen.get_next() as stream_id:
+        async with self._cross_signing_id_gen.get_next() as stream_id:
             return await self.db_pool.runInteraction(
                 "add_e2e_cross_signing_key",
                 self._set_e2e_cross_signing_key_txn,

synapse/storage/databases/main/events.py

@@ -156,15 +156,15 @@ class PersistEventsStore:
         # Note: Multiple instances of this function cannot be in flight at
        # the same time for the same room.
 
         if backfilled:
-            stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
+            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
         else:
-            stream_ordering_manager = await self._stream_id_gen.get_next_mult(
+            stream_ordering_manager = self._stream_id_gen.get_next_mult(
                 len(events_and_contexts)
             )
 
-        with stream_ordering_manager as stream_orderings:
+        async with stream_ordering_manager as stream_orderings:
             for (event, context), stream in zip(events_and_contexts, stream_orderings):
                 event.internal_metadata.stream_ordering = stream

synapse/storage/databases/main/group_server.py

@@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
             return next_id
 
-        with await self._group_updates_id_gen.get_next() as next_id:
+        async with self._group_updates_id_gen.get_next() as next_id:
             res = await self.db_pool.runInteraction(
                 "register_user_group_membership",
                 _register_user_group_membership_txn,

synapse/storage/databases/main/presence.py

@@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter
 
 class PresenceStore(SQLBaseStore):
     async def update_presence(self, presence_states):
-        stream_ordering_manager = await self._presence_id_gen.get_next_mult(
+        stream_ordering_manager = self._presence_id_gen.get_next_mult(
             len(presence_states)
         )
 
-        with stream_ordering_manager as stream_orderings:
+        async with stream_ordering_manager as stream_orderings:
             await self.db_pool.runInteraction(
                 "update_presence",
                 self._update_presence_txn,

synapse/storage/databases/main/push_rule.py

@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
     ) -> None:
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
 
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
             if before or after:
@@ -585,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
             )
 
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             await self.db_pool.runInteraction(
@@ -616,7 +616,7 @@ class PushRuleStore(PushRulesWorkerStore):
         Raises:
             NotFoundError if the rule does not exist.
         """
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
             await self.db_pool.runInteraction(
                 "_set_push_rule_enabled_txn",
@@ -754,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore):
                 data={"actions": actions_json},
             )
 
-        with await self._push_rules_stream_id_gen.get_next() as stream_id:
+        async with self._push_rules_stream_id_gen.get_next() as stream_id:
             event_stream_ordering = self._stream_id_gen.get_current_token()
 
             await self.db_pool.runInteraction(

synapse/storage/databases/main/pusher.py

@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
         last_stream_ordering,
         profile_tag="",
     ) -> None:
-        with await self._pushers_id_gen.get_next() as stream_id:
+        async with self._pushers_id_gen.get_next() as stream_id:
            # no need to lock because `pushers` has a unique key on
             # (app_id, pushkey, user_name) so simple_upsert will retry
             await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
                 },
             )
 
-        with await self._pushers_id_gen.get_next() as stream_id:
+        async with self._pushers_id_gen.get_next() as stream_id:
             await self.db_pool.runInteraction(
                 "delete_pusher", delete_pusher_txn, stream_id
             )

synapse/storage/databases/main/receipts.py

@@ -524,7 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
             "insert_receipt_conv", graph_to_linear
         )
 
-        with await self._receipts_id_gen.get_next() as stream_id:
+        async with self._receipts_id_gen.get_next() as stream_id:
             event_ts = await self.db_pool.runInteraction(
                 "insert_linearized_receipt",
                 self.insert_linearized_receipt_txn,

synapse/storage/databases/main/room.py

@@ -1137,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 },
             )
 
-        with await self._public_room_id_gen.get_next() as next_id:
+        async with self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "store_room_txn", store_room_txn, next_id
             )
@@ -1204,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 },
             )
 
-        with await self._public_room_id_gen.get_next() as next_id:
+        async with self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public", set_room_is_public_txn, next_id
             )
@@ -1284,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                 },
             )
 
-        with await self._public_room_id_gen.get_next() as next_id:
+        async with self._public_room_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "set_room_is_public_appservice",
                 set_room_is_public_appservice_txn,

synapse/storage/databases/main/tags.py

@@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
             )
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))
@@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
             txn.execute(sql, (user_id, room_id, tag))
             self._update_revision_txn(txn, user_id, room_id, next_id)
 
-        with await self._account_data_id_gen.get_next() as next_id:
+        async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
 
         self.get_tags_for_user.invalidate((user_id,))

synapse/storage/util/id_generators.py

@@ -12,14 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import contextlib
 import heapq
 import logging
 import threading
 from collections import deque
-from typing import Dict, List, Set
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Set, Union
 
+import attr
 from typing_extensions import Deque
 
 from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -86,7 +86,7 @@ class StreamIdGenerator:
         upwards, -1 to grow downwards.
 
     Usage:
-        with await stream_id_gen.get_next() as stream_id:
+        async with stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
@@ -101,10 +101,10 @@ class StreamIdGenerator:
         )
         self._unfinished_ids = deque()  # type: Deque[int]
 
-    async def get_next(self):
+    def get_next(self):
         """
         Usage:
-            with await stream_id_gen.get_next() as stream_id:
+            async with stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
         with self._lock:
@@ -113,7 +113,7 @@ class StreamIdGenerator:
             self._unfinished_ids.append(next_id)
 
-        @contextlib.contextmanager
+        @contextmanager
         def manager():
             try:
                 yield next_id
@@ -121,12 +121,12 @@ class StreamIdGenerator:
                 with self._lock:
                     self._unfinished_ids.remove(next_id)
 
-        return manager()
+        return _AsyncCtxManagerWrapper(manager())
 
-    async def get_next_mult(self, n):
+    def get_next_mult(self, n):
         """
         Usage:
-            with await stream_id_gen.get_next(n) as stream_ids:
+            async with stream_id_gen.get_next(n) as stream_ids:
                 # ... persist events ...
         """
         with self._lock:
@@ -140,7 +140,7 @@ class StreamIdGenerator:
             for next_id in next_ids:
                 self._unfinished_ids.append(next_id)
 
-        @contextlib.contextmanager
+        @contextmanager
         def manager():
             try:
                 yield next_ids
@@ -149,7 +149,7 @@ class StreamIdGenerator:
                    for next_id in next_ids:
                         self._unfinished_ids.remove(next_id)
 
-        return manager()
+        return _AsyncCtxManagerWrapper(manager())
 
     def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
@@ -282,59 +282,23 @@ class MultiWriterIdGenerator:
     def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
         return self._sequence_gen.get_next_mult_txn(txn, n)
 
-    async def get_next(self):
+    def get_next(self):
         """
         Usage:
-            with await stream_id_gen.get_next() as stream_id:
+            async with stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
-        next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
-
-        # Assert the fetched ID is actually greater than what we currently
-        # believe the ID to be. If not, then the sequence and table have got
-        # out of sync somehow.
-        with self._lock:
-            assert self._current_positions.get(self._instance_name, 0) < next_id
-
-            self._unfinished_ids.add(next_id)
-
-        @contextlib.contextmanager
-        def manager():
-            try:
-                # Multiply by the return factor so that the ID has correct sign.
-                yield self._return_factor * next_id
-            finally:
-                self._mark_id_as_finished(next_id)
-
-        return manager()
+        return _MultiWriterCtxManager(self)
 
-    async def get_next_mult(self, n: int):
+    def get_next_mult(self, n: int):
         """
         Usage:
-            with await stream_id_gen.get_next_mult(5) as stream_ids:
+            async with stream_id_gen.get_next_mult(5) as stream_ids:
                 # ... persist events ...
         """
-        next_ids = await self._db.runInteraction(
-            "_load_next_mult_id", self._load_next_mult_id_txn, n
-        )
-
-        # Assert the fetched ID is actually greater than any ID we've already
-        # seen. If not, then the sequence and table have got out of sync
-        # somehow.
-        with self._lock:
-            assert max(self._current_positions.values(), default=0) < min(next_ids)
-
-            self._unfinished_ids.update(next_ids)
-
-        @contextlib.contextmanager
-        def manager():
-            try:
-                yield [self._return_factor * i for i in next_ids]
-            finally:
-                for i in next_ids:
-                    self._mark_id_as_finished(i)
-
-        return manager()
+        return _MultiWriterCtxManager(self, n)
 
     def get_next_txn(self, txn: LoggingTransaction):
         """
@@ -482,3 +446,61 @@ class MultiWriterIdGenerator:
             # There was a gap in seen positions, so there is nothing more to
             # do.
             break
+
+
+@attr.s(slots=True)
+class _AsyncCtxManagerWrapper:
+    """Helper class to convert a plain context manager to an async one.
+
+    This is mainly useful if you have a plain context manager but the interface
+    requires an async one.
+    """
+
+    inner = attr.ib()
+
+    async def __aenter__(self):
+        return self.inner.__enter__()
+
+    async def __aexit__(self, exc_type, exc, tb):
+        return self.inner.__exit__(exc_type, exc, tb)
+
+
+@attr.s(slots=True)
+class _MultiWriterCtxManager:
+    """Async context manager returned by MultiWriterIdGenerator"""
+
+    id_gen = attr.ib(type=MultiWriterIdGenerator)
+    multiple_ids = attr.ib(type=Optional[int], default=None)
+    stream_ids = attr.ib(type=List[int], factory=list)
+
+    async def __aenter__(self) -> Union[int, List[int]]:
+        self.stream_ids = await self.id_gen._db.runInteraction(
+            "_load_next_mult_id",
+            self.id_gen._load_next_mult_id_txn,
+            self.multiple_ids or 1,
+        )
+
+        # Assert the fetched ID is actually greater than any ID we've already
+        # seen. If not, then the sequence and table have got out of sync
+        # somehow.
+        with self.id_gen._lock:
+            assert max(self.id_gen._current_positions.values(), default=0) < min(
+                self.stream_ids
+            )
+
+            self.id_gen._unfinished_ids.update(self.stream_ids)
+
+        if self.multiple_ids is None:
+            return self.stream_ids[0] * self.id_gen._return_factor
+        else:
+            return [i * self.id_gen._return_factor for i in self.stream_ids]
+
+    async def __aexit__(self, exc_type, exc, tb):
+        for i in self.stream_ids:
+            self.id_gen._mark_id_as_finished(i)
+
+        if exc_type is not None:
+            return False
+
+        return False
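For context, the wrapper pattern above can be exercised on its own. A minimal, self-contained sketch (the names here are illustrative, not the Synapse API): `StreamIdGenerator.get_next()` keeps its synchronous `@contextmanager` helper internally and simply wraps it, so callers can switch to `async with` even though nothing in the wrapped manager actually awaits.

import asyncio
from contextlib import contextmanager


class AsyncCtxManagerWrapper:
    # Same trick as `_AsyncCtxManagerWrapper`: satisfy the async context
    # manager protocol by delegating to a plain, synchronous one.
    def __init__(self, inner):
        self.inner = inner

    async def __aenter__(self):
        return self.inner.__enter__()

    async def __aexit__(self, exc_type, exc, tb):
        return self.inner.__exit__(exc_type, exc, tb)


@contextmanager
def _manager():
    try:
        yield 42  # hand out the "stream ID"
    finally:
        pass  # e.g. drop the ID from the unfinished set


async def main():
    async with AsyncCtxManagerWrapper(_manager()) as stream_id:
        assert stream_id == 42


asyncio.run(main())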

tests/storage/test_id_generators.py

@@ -111,7 +111,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # advanced after we leave the context manager.
 
         async def _get_next_async():
-            with await id_gen.get_next() as stream_id:
+            async with id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 8)
 
                 self.assertEqual(id_gen.get_positions(), {"master": 7})
@@ -139,10 +139,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         ctx3 = self.get_success(id_gen.get_next())
         ctx4 = self.get_success(id_gen.get_next())
 
-        s1 = ctx1.__enter__()
-        s2 = ctx2.__enter__()
-        s3 = ctx3.__enter__()
-        s4 = ctx4.__enter__()
+        s1 = self.get_success(ctx1.__aenter__())
+        s2 = self.get_success(ctx2.__aenter__())
+        s3 = self.get_success(ctx3.__aenter__())
+        s4 = self.get_success(ctx4.__aenter__())
 
         self.assertEqual(s1, 8)
         self.assertEqual(s2, 9)
@@ -152,22 +152,22 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_positions(), {"master": 7})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
-        ctx2.__exit__(None, None, None)
+        self.get_success(ctx2.__aexit__(None, None, None))
 
         self.assertEqual(id_gen.get_positions(), {"master": 7})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
 
-        ctx1.__exit__(None, None, None)
+        self.get_success(ctx1.__aexit__(None, None, None))
 
         self.assertEqual(id_gen.get_positions(), {"master": 9})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
 
-        ctx4.__exit__(None, None, None)
+        self.get_success(ctx4.__aexit__(None, None, None))
 
         self.assertEqual(id_gen.get_positions(), {"master": 9})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 9)
 
-        ctx3.__exit__(None, None, None)
+        self.get_success(ctx3.__aexit__(None, None, None))
 
         self.assertEqual(id_gen.get_positions(), {"master": 11})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), 11)
@@ -190,7 +190,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # advanced after we leave the context manager.
 
         async def _get_next_async():
-            with await first_id_gen.get_next() as stream_id:
+            async with first_id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 8)
 
                 self.assertEqual(
@@ -208,7 +208,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         # stream ID
 
         async def _get_next_async():
-            with await second_id_gen.get_next() as stream_id:
+            async with second_id_gen.get_next() as stream_id:
                 self.assertEqual(stream_id, 9)
 
                 self.assertEqual(
@@ -305,10 +305,14 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
         self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
         self.assertEqual(id_gen.get_persisted_upto_position(), 3)
 
-        with self.get_success(id_gen.get_next()) as stream_id:
-            self.assertEqual(stream_id, 6)
-            self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+        async def _get_next_async():
+            async with id_gen.get_next() as stream_id:
+                self.assertEqual(stream_id, 6)
+                self.assertEqual(id_gen.get_persisted_upto_position(), 3)
+
+        self.get_success(_get_next_async())
 
         self.assertEqual(id_gen.get_persisted_upto_position(), 6)
 
         # We assume that so long as `get_next` does correctly advance the
@@ -373,17 +377,23 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         """
         id_gen = self._create_id_generator()
 
-        with self.get_success(id_gen.get_next()) as stream_id:
-            self._insert_row("master", stream_id)
+        async def _get_next_async():
+            async with id_gen.get_next() as stream_id:
+                self._insert_row("master", stream_id)
+
+        self.get_success(_get_next_async())
 
         self.assertEqual(id_gen.get_positions(), {"master": -1})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
         self.assertEqual(id_gen.get_persisted_upto_position(), -1)
 
-        with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
-            for stream_id in stream_ids:
-                self._insert_row("master", stream_id)
+        async def _get_next_async2():
+            async with id_gen.get_next_mult(3) as stream_ids:
+                for stream_id in stream_ids:
+                    self._insert_row("master", stream_id)
+
+        self.get_success(_get_next_async2())
 
         self.assertEqual(id_gen.get_positions(), {"master": -4})
         self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
         self.assertEqual(id_gen.get_persisted_upto_position(), -4)
@@ -402,19 +412,25 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         id_gen_1 = self._create_id_generator("first")
         id_gen_2 = self._create_id_generator("second")
 
-        with self.get_success(id_gen_1.get_next()) as stream_id:
-            self._insert_row("first", stream_id)
-            id_gen_2.advance("first", stream_id)
+        async def _get_next_async():
+            async with id_gen_1.get_next() as stream_id:
+                self._insert_row("first", stream_id)
+                id_gen_2.advance("first", stream_id)
+
+        self.get_success(_get_next_async())
 
         self.assertEqual(id_gen_1.get_positions(), {"first": -1})
         self.assertEqual(id_gen_2.get_positions(), {"first": -1})
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
 
-        with self.get_success(id_gen_2.get_next()) as stream_id:
-            self._insert_row("second", stream_id)
-            id_gen_1.advance("second", stream_id)
+        async def _get_next_async2():
+            async with id_gen_2.get_next() as stream_id:
+                self._insert_row("second", stream_id)
+                id_gen_1.advance("second", stream_id)
+
+        self.get_success(_get_next_async2())
 
         self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
         self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
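A note on the test changes above: `async with` cannot appear in a synchronous test method, so each call site is wrapped in a local coroutine that the Twisted-based harness drives to completion with `self.get_success(...)`. A rough stand-alone equivalent using plain `asyncio` (the `FakeIdGen` here is a hypothetical stand-in, not the class under test):

import asyncio


class FakeIdGen:
    # Hypothetical generator with the same async-with surface as above.
    def __init__(self, start: int = 7):
        self._current = start

    def get_next(self):
        gen = self

        class _Ctx:
            async def __aenter__(self):
                gen._current += 1
                return gen._current

            async def __aexit__(self, exc_type, exc, tb):
                return False

        return _Ctx()


def test_get_next():
    id_gen = FakeIdGen()

    # Mirrors the pattern in the diff: wrap the `async with` in a local
    # coroutine, then run it to completion synchronously (the Synapse
    # tests use `self.get_success(...)` instead of asyncio.run).
    async def _get_next_async():
        async with id_gen.get_next() as stream_id:
            assert stream_id == 8

    asyncio.run(_get_next_async())


test_get_next()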